class DmaCopyBackend:
"""cuMemcpyBatchAsync copy backend (background thread)."""
def __init__(self) -> None:
self._store_params: BatchMemcpyParams | None = None
self._load_params: BatchMemcpyParams | None = None
self._load_stream: torch.cuda.Stream | None = None
self._store_stream: torch.cuda.Stream | None = None
self._queue: queue.SimpleQueue | None = None
self._thread: threading.Thread | None = None
self._shutdown: bool = False
def init(
self,
gpu_caches: dict[str, torch.Tensor],
cpu_caches: dict[str, torch.Tensor],
device: torch.device,
load_stream: torch.cuda.Stream,
store_stream: torch.cuda.Stream,
) -> None:
self._load_stream = load_stream
self._store_stream = store_stream
# Stores read the live KV cache -> STREAM (paired with the compute-done
# wait in get_finished); loads read stable pinned host memory -> ANY.
self._store_params = build_params(
gpu_caches,
cpu_caches,
store_stream,
src_access_order=CU_MEMCPY_SRC_ACCESS_ORDER_STREAM,
)
self._load_params = build_params(
cpu_caches,
gpu_caches,
load_stream,
src_access_order=CU_MEMCPY_SRC_ACCESS_ORDER_ANY,
)
self._queue = queue.SimpleQueue()
self._thread = threading.Thread(
target=self._copy_loop,
args=(self._queue, device, load_stream, store_stream),
daemon=True,
)
self._thread.start()
def launch_copy(
self,
src_blocks: list[int],
dst_blocks: list[int],
is_store: bool,
event_idx: int,
events_list: list[tuple[int, torch.Event]],
wait_event: torch.Event | None = None,
) -> None:
params = self._store_params if is_store else self._load_params
assert params is not None and self._queue is not None
self._queue.put(
(
src_blocks,
dst_blocks,
params,
is_store,
event_idx,
events_list,
wait_event,
)
)
def shutdown(self) -> None:
if self._shutdown:
return
self._shutdown = True
if self._queue is not None:
self._queue.put(None)
if self._thread is not None:
self._thread.join(timeout=5.0)
@staticmethod
def _copy_loop(
q: queue.SimpleQueue,
device: torch.device,
load_stream: torch.cuda.Stream,
store_stream: torch.cuda.Stream,
) -> None:
current_platform.set_device(device)
while True:
item = q.get()
if item is None:
return
(
src_blocks,
dst_blocks,
params,
is_store,
event_idx,
events_list,
wait_event,
) = item
stream = store_stream if is_store else load_stream
if wait_event is not None:
stream.wait_event(wait_event)
copy_blocks(src_blocks, dst_blocks, params)
event = torch.Event()
event.record(stream)
events_list.append((event_idx, event))