Skip to content

vllm.v1.attention.backends.triton_attn_diffkv

Triton attention backend with different K/V head dimensions (DiffKV).

The KV cache layout is identical to that of `FlashAttentionDiffKVBackend`: K and V are packed along the last dimension:

`[num_blocks, block_size, num_kv_heads, head_size_qk + head_size_v]`

so the existing helper `triton_reshape_and_cache_flash_diffkv` is reused for cache writes.

Classes:

TritonAttentionDiffKVImpl

Bases: TritonAttentionImpl

Triton attention impl for the DiffKV packed KV cache layout.

Methods:

Source code in vllm/v1/attention/backends/triton_attn_diffkv.py
class TritonAttentionDiffKVImpl(TritonAttentionImpl):
    """Triton attention impl for a packed K/V cache with distinct head sizes.

    K and V share one cache tensor whose last dimension holds
    ``head_size_qk + head_size_v`` channels: K first, V immediately after.
    Quantized KV caches, per-token-head quantization and chunked attention
    with lookback are rejected when the impl is constructed.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Guard clauses, evaluated in order: the first unsupported feature
        # encountered is the one reported to the caller.
        dtype = self.kv_cache_dtype
        if is_quantized_kv_cache(dtype):
            raise NotImplementedError(
                "TritonAttentionDiffKVBackend does not yet support quantized "
                f"KV cache (got kv_cache_dtype={dtype!r})."
            )
        if self._is_per_token_head_quant:
            raise NotImplementedError(
                "TritonAttentionDiffKVBackend does not support per-token-head "
                "quantization."
            )
        if self.chunk_lookback > -1:
            raise NotImplementedError(
                "TritonAttentionDiffKVBackend does not support chunked "
                "attention with lookback."
            )

    def do_kv_cache_update(
        self,
        layer: AttentionLayer,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> None:
        """Write this step's K and V into the packed cache.

        The diffkv reshape kernel stores K in ``[..., :head_size_qk]`` and V
        in ``[..., head_size_qk:head_size_qk + head_size_v]`` of
        ``kv_cache``, using ``slot_mapping`` and the layer's K/V scales.
        """
        k_scale = layer._k_scale
        v_scale = layer._v_scale
        triton_reshape_and_cache_flash_diffkv(
            key, value, kv_cache, slot_mapping, self.kv_cache_dtype, k_scale, v_scale
        )

    def fused_rope_kvcache_supported(self):
        """Report that the fused RoPE + cache-write path is unavailable.

        That path assumes the standard two-tensor K/V cache layout rather
        than the packed DiffKV layout used here.
        """
        return False

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        output: torch.Tensor,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention over the packed DiffKV cache, writing into ``output``.

        ``key`` and ``value`` are accepted for interface compatibility but are
        not read here; the kernel reads K and V back out of ``kv_cache``.

        Shapes:
            query:    [num_tokens, num_heads, head_size_qk]
            key:      [num_tokens, num_kv_heads, head_size_qk]
            value:    [num_tokens, num_kv_heads, head_size_v]
            kv_cache: [num_blocks, block_size, num_kv_heads,
                       head_size_qk + head_size_v]
            output:   [num_tokens, num_heads, head_size_v]

        Returns:
            ``output`` itself, filled in place (zero-filled when
            ``attn_metadata`` is None).

        Raises:
            NotImplementedError: if ``output_scale`` or ``output_block_scale``
                is supplied (fused output quantization is unsupported).
        """
        fused_quant_requested = not (
            output_scale is None and output_block_scale is None
        )
        if fused_quant_requested:
            raise NotImplementedError(
                "fused output quantization is not supported for "
                "TritonAttentionDiffKVImpl"
            )

        if attn_metadata is None:
            # No metadata to attend with (presumably a profiling / dummy run;
            # confirm with callers): hand back a zeroed output buffer.
            return output.fill_(0)

        assert attn_metadata.use_cascade is False, (
            "Cascade attention not supported for TritonAttentionDiffKVImpl"
        )

        qk_dim = self.head_size
        v_dim = TritonAttentionDiffKVBackend.head_size_v
        n_live = attn_metadata.num_actual_tokens

        # Carve K and V views out of the packed last dimension. Only the last
        # dim is sliced, so the leading strides are inherited from kv_cache.
        k_view = kv_cache[..., :qk_dim]
        v_view = kv_cache[..., qk_dim : qk_dim + v_dim]

        unified_attention_diffkv(
            q=query[:n_live],
            k=k_view,
            v=v_view,
            out=output[:n_live],
            cu_seqlens_q=attn_metadata.query_start_loc,
            seqused_k=attn_metadata.seq_lens,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            use_alibi_sqrt=self.use_alibi_sqrt,
            window_size=self.sliding_window,
            block_table=attn_metadata.block_table,
            softcap=self.logits_soft_cap,
            sinks=self.sinks,
            max_seqlen_q=attn_metadata.max_query_len,
            seq_threshold_3D=attn_metadata.seq_threshold_3D,
            num_par_softmax_segments=attn_metadata.num_par_softmax_segments,
            softmax_segm_output=attn_metadata.softmax_segm_output,
            softmax_segm_max=attn_metadata.softmax_segm_max,
            softmax_segm_expsum=attn_metadata.softmax_segm_expsum,
        )
        return output

forward(layer, query, key, value, kv_cache, attn_metadata, output, output_scale=None, output_block_scale=None)

Forward pass.

Shapes

- query: `[num_tokens, num_heads, head_size_qk]`
- key: `[num_tokens, num_kv_heads, head_size_qk]`
- value: `[num_tokens, num_kv_heads, head_size_v]`
- kv_cache: `[num_blocks, block_size, num_kv_heads, head_size_qk + head_size_v]`
- output: `[num_tokens, num_heads, head_size_v]`

Source code in vllm/v1/attention/backends/triton_attn_diffkv.py
def forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: TritonAttentionMetadata,
    output: torch.Tensor,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute DiffKV attention from the packed cache into ``output``.

    ``key`` and ``value`` are part of the interface but are not read here;
    K and V are taken from slices of ``kv_cache``.

    Shapes:
        query:    [num_tokens, num_heads, head_size_qk]
        key:      [num_tokens, num_kv_heads, head_size_qk]
        value:    [num_tokens, num_kv_heads, head_size_v]
        kv_cache: [num_blocks, block_size, num_kv_heads,
                   head_size_qk + head_size_v]
        output:   [num_tokens, num_heads, head_size_v]

    Returns:
        The same ``output`` tensor, written in place; it is zero-filled when
        ``attn_metadata`` is None.

    Raises:
        NotImplementedError: when either output-quantization scale is given.
    """
    if any(scale is not None for scale in (output_scale, output_block_scale)):
        raise NotImplementedError(
            "fused output quantization is not supported for "
            "TritonAttentionDiffKVImpl"
        )

    if attn_metadata is None:
        # Nothing to attend over; return a zeroed buffer.
        return output.fill_(0)

    assert attn_metadata.use_cascade is False, (
        "Cascade attention not supported for TritonAttentionDiffKVImpl"
    )

    meta = attn_metadata
    tokens = meta.num_actual_tokens
    k_width = self.head_size
    v_width = TritonAttentionDiffKVBackend.head_size_v
    # V sits directly after K in the packed last dimension.
    v_end = k_width + v_width

    unified_attention_diffkv(
        q=query[:tokens],
        k=kv_cache[..., :k_width],
        v=kv_cache[..., k_width:v_end],
        out=output[:tokens],
        cu_seqlens_q=meta.query_start_loc,
        seqused_k=meta.seq_lens,
        softmax_scale=self.scale,
        causal=True,
        alibi_slopes=self.alibi_slopes,
        use_alibi_sqrt=self.use_alibi_sqrt,
        window_size=self.sliding_window,
        block_table=meta.block_table,
        softcap=self.logits_soft_cap,
        sinks=self.sinks,
        max_seqlen_q=meta.max_query_len,
        seq_threshold_3D=meta.seq_threshold_3D,
        num_par_softmax_segments=meta.num_par_softmax_segments,
        softmax_segm_output=meta.softmax_segm_output,
        softmax_segm_max=meta.softmax_segm_max,
        softmax_segm_expsum=meta.softmax_segm_expsum,
    )
    return output

TritonAttentionDiffKVMetadataBuilder

Bases: TritonAttentionMetadataBuilder

Override the parent's softmax buffer last-dim to head_size_v.

The parent allocates `softmax_segm_output` with its last dimension sized to `next_power_of_2(head_size)`, i.e. the Q/K head size. For DiffKV, the accumulator and per-segment partial outputs have the V head size, so the buffer is re-allocated with `next_power_of_2(head_size_v)` instead.

Source code in vllm/v1/attention/backends/triton_attn_diffkv.py
class TritonAttentionDiffKVMetadataBuilder(TritonAttentionMetadataBuilder):
    """Override the parent's softmax buffer last-dim to head_size_v.

    The parent allocates ``softmax_segm_output`` with last-dim sized to
    ``next_power_of_2(head_size)`` (== Q/K head size).  For DiffKV the
    accumulator and per-segment partial outputs are V-shaped, so we
    re-allocate with ``next_power_of_2(head_size_v)`` instead.
    """

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ) -> None:
        """Run the parent builder setup, then resize ``softmax_segm_output``.

        Args:
            kv_cache_spec: KV cache spec, forwarded unchanged to the parent.
            layer_names: Attention layer names, forwarded to the parent.
            vllm_config: Global vLLM config, forwarded to the parent.
            device: Device on which the replacement buffer is allocated.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        # Pad the V head size to a power of two, mirroring how the parent
        # pads the Q/K head size for its version of this buffer.
        head_size_v = TritonAttentionDiffKVBackend.head_size_v
        head_size_v_padded = next_power_of_2(head_size_v)
        # Replace the parent-allocated float32 buffer with one whose last dim
        # is V-sized. The leading dims are read from attributes not set in
        # this class (presumably initialized by the parent's __init__;
        # confirm against TritonAttentionMetadataBuilder).
        self.softmax_segm_output = torch.empty(
            (
                self.seq_threshold_3D,
                self.num_heads_q,
                self.num_par_softmax_segments,
                head_size_v_padded,
            ),
            dtype=torch.float32,
            device=device,
        )