Skip to content

vllm.v1.worker.gpu.buffer_utils

Classes:

  • FusedStagedWriter

    Applies the staged writes of several StagedWriteTensors at once.

FusedStagedWriter

Applies the staged writes of several StagedWriteTensors at once.

Methods:

  • apply

    Apply and clear the staged writes of `tensors` with one kernel.

Source code in vllm/v1/worker/gpu/buffer_utils.py
class FusedStagedWriter:
    """Flush the staged writes of many `StagedWriteTensor`s in one launch.

    Holds four int32 `UvaBufferPool`s, each created with size `max_writes`,
    that carry the per-write metadata (group id, index, start and cumulative
    length) handed to `_apply_write_kernel`.
    """

    def __init__(
        self, device: torch.device, max_writes: int, max_concurrency: int | None = None
    ):
        """Create the metadata pools and remember the target device.

        Args:
            device: Device the write contents are uploaded to in `apply`.
            max_writes: Size given to every metadata pool (presumably the
                upper bound on staged writes per `apply` call; confirm
                against `UvaBufferPool`).
            max_concurrency: Forwarded unchanged to every `UvaBufferPool`.
        """
        # NOTE(review): `None` is forwarded as-is; confirm `UvaBufferPool`
        # accepts it.
        pool_kwargs = {"dtype": torch.int32, "max_concurrency": max_concurrency}
        self.group_ids = UvaBufferPool(max_writes, **pool_kwargs)
        self.indices = UvaBufferPool(max_writes, **pool_kwargs)
        self.starts = UvaBufferPool(max_writes, **pool_kwargs)
        self.cu_lens = UvaBufferPool(max_writes, **pool_kwargs)
        self.device = device

    def apply(
        self,
        tensors: Sequence[StagedWriteTensor],
        output_ptrs: torch.Tensor,
        output_strides: torch.Tensor,
    ) -> None:
        """Apply and clear the staged writes of `tensors` with one kernel.

        The staged indices, starts, contents and cumulative lengths of every
        tensor are concatenated into flat lists, and each write is tagged
        with the position of its tensor in `tensors` (its group id). If no
        tensor has staged writes, the method returns without launching the
        kernel and without clearing anything.

        Args:
            tensors: Tensors whose staged writes are gathered and, after the
                launch, cleared via `clear_staged_writes()`.
            output_ptrs: Passed to the kernel unchanged (presumably one
                destination pointer per entry of `tensors`; confirm).
            output_strides: Passed to the kernel unchanged (presumably one
                row stride per entry of `tensors`; confirm).
        """
        flat_group_ids: list[int] = []
        flat_indices: list[int] = []
        flat_starts: list[int] = []
        flat_contents: list[int | float] = []
        flat_cu_lens: list[int] = []

        for gid, tensor in enumerate(tensors):
            num_staged = len(tensor._staged_write_indices)
            if num_staged > 0:
                flat_group_ids += [gid] * num_staged
                flat_indices += tensor._staged_write_indices
                flat_starts += tensor._staged_write_starts
                # Cumulative lengths are relative to the tensor's own
                # contents; shift them by what has been gathered so far so
                # they index into the flat contents list.
                offset = len(flat_contents)
                flat_contents += tensor._staged_write_contents
                flat_cu_lens += [
                    offset + end for end in tensor._staged_write_cu_lens
                ]

        total_writes = len(flat_group_ids)
        if total_writes == 0:
            return

        group_ids_uva = self.group_ids.copy_to_uva(flat_group_ids)
        indices_uva = self.indices.copy_to_uva(flat_indices)
        starts_uva = self.starts.copy_to_uva(flat_starts)
        cu_lens_uva = self.cu_lens.copy_to_uva(flat_cu_lens)
        # NOTE(review): contents are always uploaded as int32 although the
        # list is typed `int | float`; confirm callers only fuse int32 tensors.
        contents_gpu = async_tensor_h2d(
            flat_contents, device=self.device, dtype=torch.int32
        )

        # The launch grid has one entry per staged write.
        launch_grid = (total_writes,)
        _apply_write_kernel[launch_grid](
            output_ptrs,
            output_strides,
            indices_uva,
            starts_uva,
            contents_gpu,
            cu_lens_uva,
            group_ids_uva,
            BLOCK_SIZE=1024,
            MULTI_GROUP=True,
        )
        for tensor in tensors:
            tensor.clear_staged_writes()

apply(tensors, output_ptrs, output_strides)

Apply and clear the staged writes of `tensors` with one kernel.

Source code in vllm/v1/worker/gpu/buffer_utils.py
def apply(
    self,
    tensors: Sequence[StagedWriteTensor],
    output_ptrs: torch.Tensor,
    output_strides: torch.Tensor,
) -> None:
    """Apply and clear the staged writes of `tensors` with one kernel.

    The staged indices, starts, contents and cumulative lengths of every
    tensor are concatenated into flat lists, and each write is tagged with
    the position of its tensor in `tensors` (its group id). If no tensor has
    staged writes, the method returns without launching the kernel and
    without clearing anything.

    Args:
        tensors: Tensors whose staged writes are gathered and, after the
            launch, cleared via `clear_staged_writes()`.
        output_ptrs: Passed to the kernel unchanged (presumably one
            destination pointer per entry of `tensors`; confirm).
        output_strides: Passed to the kernel unchanged (presumably one row
            stride per entry of `tensors`; confirm).
    """
    flat_group_ids: list[int] = []
    flat_indices: list[int] = []
    flat_starts: list[int] = []
    flat_contents: list[int | float] = []
    flat_cu_lens: list[int] = []

    for gid, tensor in enumerate(tensors):
        num_staged = len(tensor._staged_write_indices)
        if num_staged > 0:
            flat_group_ids += [gid] * num_staged
            flat_indices += tensor._staged_write_indices
            flat_starts += tensor._staged_write_starts
            # Cumulative lengths are relative to the tensor's own contents;
            # shift them by what has been gathered so far so they index into
            # the flat contents list.
            offset = len(flat_contents)
            flat_contents += tensor._staged_write_contents
            flat_cu_lens += [offset + end for end in tensor._staged_write_cu_lens]

    total_writes = len(flat_group_ids)
    if total_writes == 0:
        return

    group_ids_uva = self.group_ids.copy_to_uva(flat_group_ids)
    indices_uva = self.indices.copy_to_uva(flat_indices)
    starts_uva = self.starts.copy_to_uva(flat_starts)
    cu_lens_uva = self.cu_lens.copy_to_uva(flat_cu_lens)
    # NOTE(review): contents are always uploaded as int32 although the list
    # is typed `int | float`; confirm callers only fuse int32 tensors.
    contents_gpu = async_tensor_h2d(
        flat_contents, device=self.device, dtype=torch.int32
    )

    # The launch grid has one entry per staged write.
    launch_grid = (total_writes,)
    _apply_write_kernel[launch_grid](
        output_ptrs,
        output_strides,
        indices_uva,
        starts_uva,
        contents_gpu,
        cu_lens_uva,
        group_ids_uva,
        BLOCK_SIZE=1024,
        MULTI_GROUP=True,
    )
    for tensor in tensors:
        tensor.clear_staged_writes()