Skip to content

vllm.v1.attention.backends.mla.triton_mla

Classes:

TritonMLAMetadataBuilder

Bases: MLACommonMetadataBuilder[MLACommonMetadata]

Source code in vllm/v1/attention/backends/mla/triton_mla.py
class TritonMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
    """MLA metadata builder for the Triton backend.

    On top of the common MLA builder setup, construction reserves the
    shared workspace used for the decode split-KV attention logits, so
    the buffer is sized before any warmup or CUDA graph capture happens.
    """

    _cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_BATCH
    )

    def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        # Reserve the logits buffer once, up front, rather than lazily.
        self._reserve_attn_logits_workspace()

    def _reserve_attn_logits_workspace(self) -> None:
        """Reserve the shared workspace for the decode split-KV attn logits.

        The reservation is sized for the worst case: ``max_num_seqs`` decode
        tokens and the number of KV splits produced by ``max_model_len``.
        Doing this before warmup/cudagraph capture means the per-call
        ``get_simultaneous`` in ``forward_mqa`` never needs to grow the
        buffer at runtime (growing would raise once the workspace is
        locked).

        Does nothing when the workspace manager has not been initialized.
        """
        if not is_workspace_manager_initialized():
            # No workspace manager yet, so there is nothing to reserve into.
            return

        # The decode reorder threshold is 1, so a decode batch holds at most
        # max_num_seqs tokens.
        max_decode_tokens = self.vllm_config.scheduler_config.max_num_seqs
        # Under DCP the query heads are all-gathered before forward_mqa, so
        # the logits carry num_heads * dcp_world_size heads.
        total_q_heads = self.num_heads * self.dcp_world_size
        # max_model_len gives the largest split count the kernel can pick.
        worst_case_splits = _compute_num_kv_splits(
            self.model_config.max_model_len,
            current_platform.num_compute_units(),
        )
        # One slot beyond kv_lora_rank per split; presumably the trailing
        # slot holds the per-split log-sum-exp (original name: lse_dim).
        shape = (
            max_decode_tokens,
            total_q_heads,
            worst_case_splits,
            self.mla_dims.kv_lora_rank + 1,
        )
        workspace = current_workspace_manager()
        workspace.get_simultaneous((shape, torch.float32))

_reserve_attn_logits_workspace()

Pre-size the shared workspace for the decode split-KV attn logits.

Reserving at the worst case (`max_model_len`, which yields the maximum `num_kv_splits`, and `max_num_seqs` decode tokens) before warmup and CUDA graph capture ensures that the per-call `get_simultaneous` in `forward_mqa` never has to grow the buffer at runtime, which would raise once the workspace is locked.

Source code in vllm/v1/attention/backends/mla/triton_mla.py
def _reserve_attn_logits_workspace(self) -> None:
    """Reserve the shared workspace for the decode split-KV attn logits.

    The reservation is sized for the worst case: ``max_num_seqs`` decode
    tokens and the number of KV splits produced by ``max_model_len``.
    Doing this before warmup/cudagraph capture means the per-call
    ``get_simultaneous`` in ``forward_mqa`` never needs to grow the
    buffer at runtime (growing would raise once the workspace is locked).

    Does nothing when the workspace manager has not been initialized.
    """
    if not is_workspace_manager_initialized():
        # No workspace manager yet, so there is nothing to reserve into.
        return

    # The decode reorder threshold is 1, so a decode batch holds at most
    # max_num_seqs tokens.
    max_decode_tokens = self.vllm_config.scheduler_config.max_num_seqs
    # Under DCP the query heads are all-gathered before forward_mqa, so the
    # logits carry num_heads * dcp_world_size heads.
    total_q_heads = self.num_heads * self.dcp_world_size
    # max_model_len gives the largest split count the kernel can pick.
    worst_case_splits = _compute_num_kv_splits(
        self.model_config.max_model_len,
        current_platform.num_compute_units(),
    )
    # One slot beyond kv_lora_rank per split; presumably the trailing slot
    # holds the per-split log-sum-exp (original name: lse_dim).
    shape = (
        max_decode_tokens,
        total_q_heads,
        worst_case_splits,
        self.mla_dims.kv_lora_rank + 1,
    )
    workspace = current_workspace_manager()
    workspace.get_simultaneous((shape, torch.float32))