
vllm.v1.attention.ops.triton_unified_attention_diffkv

Triton unified attention with different K/V head dimensions (DiffKV).

This is a slimmed fork of triton_unified_attention.py for models like MiMo-V2.5, where the V tensor's head dimension differs from K's. The KV cache uses the same packed layout as FlashAttentionDiffKVBackend:

kv_cache: [num_blocks, block_size, num_kv_heads, head_size_qk + head_size_v]

We slice key_cache = kv_cache[..., :head_size_qk] and value_cache = kv_cache[..., head_size_qk:] on the host, so the kernel receives two cache pointers, each with its own head size.
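A minimal NumPy sketch of the host-side split, assuming small hypothetical sizes (the real cache is a GPU torch tensor; basic slicing behaves the same way for this purpose). It shows that both slices are views into the packed buffer, so no copy is made, and that each view keeps the packed stride of `head_size_qk + head_size_v` elements along the head axis.

```python
import numpy as np

# Hypothetical, illustrative sizes (not any real model's configuration).
num_blocks, block_size, num_kv_heads = 4, 16, 2
head_size_qk, head_size_v = 192, 128

# Packed cache: K and V for each (block, slot, head) sit side by side.
kv_cache = np.zeros(
    (num_blocks, block_size, num_kv_heads, head_size_qk + head_size_v),
    dtype=np.float16,
)

# Host-side split: basic slicing returns views, not copies.
key_cache = kv_cache[..., :head_size_qk]
value_cache = kv_cache[..., head_size_qk:]

print(key_cache.shape)    # (4, 16, 2, 192)
print(value_cache.shape)  # (4, 16, 2, 128)

# Both views share storage with the packed buffer.
print(np.shares_memory(key_cache, kv_cache), np.shares_memory(value_cache, kv_cache))  # True True

# Strides are inherited: stepping one head advances by the packed width
# (head_size_qk + head_size_v elements), not by either head size alone.
itemsize = kv_cache.itemsize
print(key_cache.strides[2] // itemsize, value_cache.strides[2] // itemsize)  # 320 320

# A write through value_cache lands in the V half of the packed row.
value_cache[0, 0, 0, 0] = 1.0
print(kv_cache[0, 0, 0, head_size_qk])  # 1.0
```

Because the strides are inherited, the kernel has to be given the packed strides alongside the two pointers; it cannot assume either cache is contiguous in its own head size.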

Both 2D and 3D launches are supported:
  • 2D: one program per (q-block, kv-head); a tile loop walks the full KV sequence, and the final output is written directly. Used for prefill and large decode batches.
  • 3D: one program per (q-block, kv-head, segm); each program covers one KV slice and writes per-segment partials (max, expsum, output). A follow-up kernel, kernel_reduce_segments_diffkv, combines them. Selected for decode-only batches whose 2D grid would under-fill the GPU.
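The exact rule for choosing between the two launches is not spelled out here. The sketch below is a hypothetical selector under one assumed rule (switch to 3D only for decode-only batches whose 2D program count is below the SM count); `choose_launch`, `num_sms`, and the threshold are illustrative, not vLLM's API. Its purpose is to show how the two grid shapes differ.

```python
def choose_launch(
    num_q_blocks: int,
    num_kv_heads: int,
    decode_only: bool,
    num_sms: int,
    num_segments: int,
) -> tuple[str, tuple[int, ...]]:
    """Hypothetical 2D-vs-3D selection; the under-fill threshold is an assumption."""
    grid_2d = (num_q_blocks, num_kv_heads)
    programs_2d = num_q_blocks * num_kv_heads
    # Assumed rule: only decode-only batches may switch, and only when the
    # 2D grid launches fewer programs than the GPU has SMs.
    if decode_only and programs_2d < num_sms:
        # Extra segment axis: each program handles one slice of the KV sequence.
        return "3d", (num_q_blocks, num_kv_heads, num_segments)
    return "2d", grid_2d


# Small decode batch: 8 * 4 = 32 programs would leave most of 132 SMs idle.
print(choose_launch(8, 4, decode_only=True, num_sms=132, num_segments=16))
# ('3d', (8, 4, 16))

# Mixed/prefill batch: stays on the 2D grid.
print(choose_launch(512, 4, decode_only=False, num_sms=132, num_segments=16))
# ('2d', (512, 4))
```

The segment axis multiplies the program count by `num_segments`, which is how the 3D launch recovers occupancy for small batches, at the cost of the extra reduction pass.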

Functions:

kernel_reduce_segments_diffkv(output_ptr, segm_output_ptr, segm_max_ptr, segm_expsum_ptr, seq_lens_ptr, num_seqs, num_query_heads, output_stride_0, output_stride_1, TILE_SIZE, HEAD_SIZE_V, HEAD_SIZE_V_PADDED, query_start_len_ptr, BLOCK_Q, NUM_SEGMENTS_PER_SEQ)

Combine per-segment partials into the final softmax output.

Mirrors reduce_segments from triton_unified_attention.py but indexes V's head size (HEAD_SIZE_V) instead of the shared one.
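The combination step is the usual log-sum-exp merge of partial softmax results. Below is a NumPy reference for one query and one head. It assumes each segment stores its local max, its local exp-sum, and an unnormalized output accumulated against that local max, which is what the rescale-then-divide in the reduction implies. The even `np.array_split` partition is a simplification of the kernel's tile-aligned segments.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes; V's head size differs from K's.
seq_len, head_size_qk, head_size_v = 100, 192, 128
num_segments = 4

q = rng.standard_normal(head_size_qk)
k = rng.standard_normal((seq_len, head_size_qk))
v = rng.standard_normal((seq_len, head_size_v))
scores = k @ q / np.sqrt(head_size_qk)  # [seq_len]

# Reference: one softmax over all keys; output has V's head size.
p = np.exp(scores - scores.max())
ref = (p / p.sum()) @ v  # [head_size_v]

# Per-segment partials, as a 3D launch would produce them.
segments = np.array_split(np.arange(seq_len), num_segments)
segm_max = np.array([scores[idx].max() for idx in segments])
segm_expsum = np.array(
    [np.exp(scores[idx] - m).sum() for idx, m in zip(segments, segm_max)]
)
# Unnormalized outputs, each relative to its own segment max.
segm_output = np.stack(
    [np.exp(scores[idx] - m) @ v[idx] for idx, m in zip(segments, segm_max)]
)  # [num_segments, head_size_v]

# Merge, mirroring kernel_reduce_segments_diffkv.
overall_max = segm_max.max()
scale = np.exp(segm_max - overall_max)
overall_expsum = (segm_expsum * scale).sum()
acc = (segm_output * scale[:, None]).sum(axis=0) / overall_expsum

print(acc.shape)              # (128,)
print(np.allclose(acc, ref))  # True
```

Rescaling by `exp(segm_max - overall_max)` puts every partial on the same reference max before summing, so the merged result matches a single softmax over the whole sequence.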

Source code in vllm/v1/attention/ops/triton_unified_attention_diffkv.py
@triton.jit
def kernel_reduce_segments_diffkv(
    output_ptr,  # [num_tokens, num_query_heads, head_size_v]
    segm_output_ptr,
    # [num_tokens, num_query_heads, max_num_segments, head_size_v]
    segm_max_ptr,  # [num_tokens, num_query_heads, max_num_segments]
    segm_expsum_ptr,  # [num_tokens, num_query_heads, max_num_segments]
    seq_lens_ptr,  # [num_seqs]
    num_seqs,
    num_query_heads: tl.constexpr,
    output_stride_0: tl.int64,
    output_stride_1: tl.int64,  # == HEAD_SIZE_V
    TILE_SIZE: tl.constexpr,
    HEAD_SIZE_V: tl.constexpr,
    HEAD_SIZE_V_PADDED: tl.constexpr,
    query_start_len_ptr,  # [num_seqs+1]
    BLOCK_Q: tl.constexpr,
    NUM_SEGMENTS_PER_SEQ: tl.constexpr,
):
    """Combine per-segment partials into the final softmax output.

    Mirrors ``reduce_segments`` from triton_unified_attention.py but
    indexes V's head size (``HEAD_SIZE_V``) instead of the shared one.
    """
    query_token_idx = tl.program_id(0)
    query_head_idx = tl.program_id(1)

    seq_idx = find_seq_idx(
        query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False
    )
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    tiles_per_segment = cdiv_fn(seq_len, NUM_SEGMENTS_PER_SEQ * TILE_SIZE)
    act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE)
    segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full(
        [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32
    )
    dim_mask = tl.where(tl.arange(0, HEAD_SIZE_V_PADDED) < HEAD_SIZE_V, 1, 0).to(
        tl.int1
    )

    segm_offset = (
        query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ)
        + query_head_idx * NUM_SEGMENTS_PER_SEQ
        + tl.arange(0, NUM_SEGMENTS_PER_SEQ)
    )
    segm_max = tl.load(segm_max_ptr + segm_offset, mask=segm_mask, other=float("-inf"))
    overall_max = tl.max(segm_max)

    segm_expsum = tl.load(segm_expsum_ptr + segm_offset, mask=segm_mask, other=0.0)
    segm_expsum = segm_expsum * tl.exp(segm_max - overall_max)
    overall_expsum = tl.sum(segm_expsum)

    segm_output_offset = (
        query_token_idx.to(tl.int64)
        * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_V_PADDED)
        + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_V_PADDED)
        + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_V_PADDED
        + tl.arange(0, HEAD_SIZE_V_PADDED)[None, :]
    )
    segm_output = tl.load(
        segm_output_ptr + segm_output_offset,
        mask=segm_mask[:, None] & dim_mask[None, :],
        other=0.0,
    )
    segm_output *= tl.exp(segm_max - overall_max)[:, None]
    acc_sum = tl.sum(segm_output, axis=0)
    acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)

    output_offset = (
        query_token_idx * output_stride_0
        + query_head_idx * output_stride_1
        + tl.arange(0, HEAD_SIZE_V_PADDED)
    )
    tl.store(output_ptr + output_offset, acc, mask=dim_mask)
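The segment bookkeeping at the top of the kernel can be checked on the host. This pure-Python sketch reproduces the `tiles_per_segment` / `act_num_segments` arithmetic and the flat `segm_output` offset, assuming `cdiv_fn` is ordinary ceiling division and using hypothetical sizes (`TILE_SIZE=16`, `NUM_SEGMENTS_PER_SEQ=16`). Because `tl.arange` requires a power-of-two length, `HEAD_SIZE_V_PADDED` and `NUM_SEGMENTS_PER_SEQ` must be powers of two; `next_power_of_2` below is a local helper, not the Triton function.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers (assumed to match cdiv_fn)."""
    return (a + b - 1) // b


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1 (local helper)."""
    return 1 << (n - 1).bit_length()


def segment_layout(seq_len: int, tile_size: int, num_segments_per_seq: int):
    """Mirror the kernel: KV tiles per segment, then segments actually used."""
    tiles_per_segment = cdiv(seq_len, num_segments_per_seq * tile_size)
    act_num_segments = cdiv(seq_len, tiles_per_segment * tile_size)
    return tiles_per_segment, act_num_segments


def segm_output_offset(token, head, segm, d, num_query_heads, num_segments, head_size_v_padded):
    """Flat index into segm_output, equivalent to the kernel's segm_output_offset sum."""
    return ((token * num_query_heads + head) * num_segments + segm) * head_size_v_padded + d


for seq_len in (100, 300, 1000):
    print(seq_len, segment_layout(seq_len, tile_size=16, num_segments_per_seq=16))
# 100 (1, 7)
# 300 (2, 10)
# 1000 (4, 16)

print(next_power_of_2(96))  # 128: 96 active lanes out of 128, rest masked by dim_mask
print(segm_output_offset(2, 3, 5, 7, num_query_heads=8, num_segments=16, head_size_v_padded=128))
# 39559
```

Short sequences use only some of the segment slots (7 of 16 for 100 tokens here). The kernel masks the unused slots with `-inf` for the max and `0.0` for the exp-sum and output, so they contribute nothing to the merge.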