class FusedStagedWriter:
    """Applies the staged writes of several `StagedWriteTensor`s at once.

    `apply` concatenates the staged writes of every tensor it is given and
    issues a single `_apply_write_kernel` launch that covers all of them.
    """

    def __init__(
        self, device: torch.device, max_writes: int, max_concurrency: int | None = None
    ):
        """Allocate the int32 staging pools that hold per-write metadata.

        Args:
            device: Device the write contents are uploaded to in `apply`.
            max_writes: Size handed to every `UvaBufferPool`; presumably the
                largest number of staged writes one `apply` call may carry
                -- TODO confirm against `UvaBufferPool`.
            max_concurrency: Forwarded unchanged to every `UvaBufferPool`.
                NOTE(review): `None` is forwarded as-is; confirm the pool
                accepts it.
        """
        self.device = device

        # All four pools share the same construction parameters.
        pool_kwargs = {"dtype": torch.int32, "max_concurrency": max_concurrency}
        self.group_ids = UvaBufferPool(max_writes, **pool_kwargs)
        self.indices = UvaBufferPool(max_writes, **pool_kwargs)
        self.starts = UvaBufferPool(max_writes, **pool_kwargs)
        self.cu_lens = UvaBufferPool(max_writes, **pool_kwargs)

    def apply(
        self,
        tensors: Sequence[StagedWriteTensor],
        output_ptrs: torch.Tensor,
        output_strides: torch.Tensor,
    ) -> None:
        """Apply and clear the staged writes of `tensors` with one kernel.

        Returns immediately, without launching anything or clearing any
        tensor, when none of `tensors` has a staged write.

        Args:
            tensors: Tensors whose staged writes are flushed. A tensor's
                position in this sequence is the group id recorded for each
                of its writes.
            output_ptrs: Passed straight through to the kernel; presumably one
                destination pointer per entry of `tensors` -- TODO confirm.
            output_strides: Passed straight through to the kernel; presumably
                one stride per entry of `tensors` -- TODO confirm.
        """
        flat_group_ids: list[int] = []
        flat_indices: list[int] = []
        flat_starts: list[int] = []
        flat_contents: list[int | float] = []
        flat_cu_lens: list[int] = []

        for gid, tensor in enumerate(tensors):
            num_writes = len(tensor._staged_write_indices)
            if num_writes > 0:
                # Each tensor's cu_lens are relative to its own contents, so
                # rebase them by the amount of content gathered so far. The
                # offset must be taken before this tensor's contents are added.
                offset = len(flat_contents)
                flat_group_ids += [gid] * num_writes
                flat_indices += tensor._staged_write_indices
                flat_starts += tensor._staged_write_starts
                flat_contents += tensor._staged_write_contents
                flat_cu_lens += [
                    offset + cu_len for cu_len in tensor._staged_write_cu_lens
                ]

        total_writes = len(flat_group_ids)
        if total_writes == 0:
            return

        group_ids_uva = self.group_ids.copy_to_uva(flat_group_ids)
        indices_uva = self.indices.copy_to_uva(flat_indices)
        starts_uva = self.starts.copy_to_uva(flat_starts)
        cu_lens_uva = self.cu_lens.copy_to_uva(flat_cu_lens)
        # NOTE(review): contents are always uploaded as int32 even though the
        # list is typed `int | float`; confirm that only int32 tensors are
        # ever routed through this writer.
        contents_gpu = async_tensor_h2d(
            flat_contents, dtype=torch.int32, device=self.device
        )

        # The launch grid has one entry per staged write.
        launch_grid = (total_writes,)
        _apply_write_kernel[launch_grid](
            output_ptrs,
            output_strides,
            indices_uva,
            starts_uva,
            contents_gpu,
            cu_lens_uva,
            group_ids_uva,
            BLOCK_SIZE=1024,
            MULTI_GROUP=True,
        )

        # Every tensor is cleared, including those that had nothing staged.
        for tensor in tensors:
            tensor.clear_staged_writes()