vllm.v1.attention.ops.rocm_aiter_mla_sparse

Functions:

_decode_num_splits(num_queries, heads_blocks, avg_main_len=0.0, avg_extra_len=0.0, block_k=32)

Pick a flash-decode split count to keep the GPU busy across batch sizes.

Without splitting, decode launches only num_queries * heads_blocks workgroups, which severely under-fills the device in the low-concurrency regime that dominates latency. Splitting the KV sequence adds parallelism.

We model the relative partial-kernel latency for a given split count s as waves * (1/s + mu) where waves = ceil(base * s / CU) and mu is a small per-wave overhead penalty:

  • waves / s captures the partial compute: each wave walks roughly total_tokens / s tokens and there are waves of them, so dividing by s makes more splits cheaper until they spill into extra waves.
  • mu * waves charges per-wave launch/tail overhead so we do not over-split into many mostly-idle waves (e.g. batch 224 on 256 CUs is best left at 1 split rather than 8 splits across 7 waves).

The minimiser naturally prefers split counts that pack the device into full waves (base * s near a multiple of CU) and falls back to 1 split once the batch already fills the device. Ties favour the smaller split count (less reduce work).
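The cost model can be checked by hand for a few batch sizes. The following is a minimal standalone sketch, not the library code: split_cost and pick_splits are hypothetical names, and the CU count of 256 is an assumption borrowed from the example above (the real function obtains it from _decode_cu_count()).

```python
def split_cost(base: int, splits: int, cu: int, mu: float = 0.04) -> float:
    """Relative latency estimate: waves * (1/splits + mu)."""
    waves = (base * splits + cu - 1) // cu  # integer ceil(base * splits / cu)
    return waves * (1.0 / splits + mu)


def pick_splits(base: int, cu: int = 256, max_splits: int = 16) -> int:
    """Return the split count with the lowest modelled cost (ties -> smaller)."""
    best, best_cost = 1, split_cost(base, 1, cu)
    for s in range(2, max_splits + 1):
        cost = split_cost(base, s, cu)
        if cost < best_cost - 1e-9:  # accept strict improvements only
            best, best_cost = s, cost
    return best


for base in (1, 24, 224, 256):
    print(base, pick_splits(base))
```

Under these assumptions, base 224 stays at 1 split (cost 1.04, against 7 * 0.165 = 1.155 for 8 splits across 7 waves), base 24 picks 10 splits (240 workgroups still fit in one wave), and base 256 already fills the device, so it also stays at 1 split.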

Finally we "snap down" the chosen split count to the smallest value that yields the same wave count and the same per-workgroup BLOCK_K iteration count. Because latency tracks iteration count (not raw token count), extra splits that do not lower the iteration count add only reduce/HBM overhead for no parallelism gain (e.g. batch 24: s8 and s10 both walk 4 extra iters in one wave, so s8 is strictly better). Snapping needs the average segment lengths, which the caller derives sync-free from the ragged index sizes.
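The snap-down step can be illustrated in isolation. The sketch below is a hedged reconstruction: snap_down and partial_iters are hypothetical names, the CU count of 256 is assumed, and the extra-segment length of 1024 tokens is an illustrative value chosen so that 8 and 10 splits both need 4 iterations, as in the batch 24 example.

```python
import math


def partial_iters(main_len: float, extra_len: float, splits: int, block_k: int = 32) -> int:
    """BLOCK_K iterations per workgroup; main and extra segments are counted separately."""

    def seg_iters(seg_len: float) -> int:
        if seg_len <= 0:
            return 0
        return math.ceil(math.ceil(seg_len / splits) / block_k)

    return seg_iters(main_len) + seg_iters(extra_len)


def snap_down(base: int, chosen: int, main_len: float, extra_len: float,
              cu: int = 256, block_k: int = 32) -> int:
    """Smallest split count with the same wave and iteration counts as `chosen`."""

    def waves(splits: int) -> int:
        return (base * splits + cu - 1) // cu

    target = (waves(chosen), partial_iters(main_len, extra_len, chosen, block_k))
    for s in range(1, chosen):
        if (waves(s), partial_iters(main_len, extra_len, s, block_k)) == target:
            return s
    return chosen


# Batch 24, cost model chose 10 splits, hypothetical 1024-token extra segment.
print(snap_down(base=24, chosen=10, main_len=0.0, extra_len=1024.0))
```

With these inputs, 7 splits would still need 5 iterations per workgroup, while 8, 9 and 10 splits all need 4 in a single wave, so the sketch settles on 8.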

Source code in vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
def _decode_num_splits(
    num_queries: int,
    heads_blocks: int,
    avg_main_len: float = 0.0,
    avg_extra_len: float = 0.0,
    block_k: int = 32,
) -> int:
    """Pick a flash-decode split count to keep the GPU busy across batch sizes.

    Without splitting, decode launches only ``num_queries * heads_blocks``
    workgroups, which severely under-fills the device in the low-concurrency
    regime that dominates latency. Splitting the KV sequence adds parallelism.

    We model the relative partial-kernel latency for a given split count ``s``
    as ``waves * (1/s + mu)`` where ``waves = ceil(base * s / CU)`` and ``mu``
    is a small per-wave overhead penalty:

      - ``waves / s`` captures the partial compute: each wave walks roughly
        ``total_tokens / s`` tokens and there are ``waves`` of them, so dividing
        by ``s`` makes more splits cheaper *until* they spill into extra waves.
      - ``mu * waves`` charges per-wave launch/tail overhead so we do not
        over-split into many mostly-idle waves (e.g. batch 224 on 256 CUs is
        best left at 1 split rather than 8 splits across 7 waves).

    The minimiser naturally prefers split counts that pack the device into full
    waves (``base * s`` near a multiple of ``CU``) and falls back to 1 split
    once the batch already fills the device. Ties favour the smaller split
    count (less reduce work).

    Finally we "snap down" the chosen split count to the smallest value that
    yields the same wave count *and* the same per-workgroup BLOCK_K iteration
    count. Because latency tracks iteration count (not raw token count), extra
    splits that do not lower the iteration count add only reduce/HBM overhead
    for no parallelism gain (e.g. batch 24: s8 and s10 both walk 4 extra iters
    in one wave, so s8 is strictly better). Snapping needs the average segment
    lengths, which the caller derives sync-free from the ragged index sizes.
    """
    base = max(1, num_queries * heads_blocks)
    # Target ~1 workgroup per CU: enough to fill the device while keeping the
    # reduce cost (which grows with split count) small. Tuned on gfx950.
    cu = max(1, _decode_cu_count())
    # Per-wave overhead penalty: higher values discourage split counts that
    # spill into extra GPU waves. Tuned on gfx950.
    mu = 0.04
    best_splits = 1
    best_cost = None
    # Search up to 16 splits; beyond that the reduce/HBM overhead dominates.
    for splits in range(1, 17):
        waves = (base * splits + cu - 1) // cu
        cost = waves * (1.0 / splits + mu)
        if best_cost is None or cost < best_cost - 1e-9:
            best_splits = splits
            best_cost = cost

    if best_splits > 1 and (avg_main_len > 0 or avg_extra_len > 0):
        target_waves = (base * best_splits + cu - 1) // cu
        target_iters = _decode_partial_iters(
            avg_main_len, avg_extra_len, best_splits, block_k
        )
        for splits in range(1, best_splits):
            waves = (base * splits + cu - 1) // cu
            iters = _decode_partial_iters(avg_main_len, avg_extra_len, splits, block_k)
            if waves == target_waves and iters == target_iters:
                best_splits = splits
                break
    return best_splits

_decode_partial_iters(avg_main_len, avg_extra_len, splits, block_k)

BLOCK_K iterations one partial workgroup walks for splits splits.

Each split processes ceil(seg_len / splits) tokens of a segment, walked BLOCK_K at a time, and the main/extra segments are handled separately.
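Written as a formula, with B standing for BLOCK_K and a term dropped when its segment length is zero, the iteration count described above is:

```latex
\[
\mathrm{iters}(s) =
  \left\lceil \frac{\lceil L_{\mathrm{main}} / s \rceil}{B} \right\rceil
  +
  \left\lceil \frac{\lceil L_{\mathrm{extra}} / s \rceil}{B} \right\rceil
\]
```

As a worked example with hypothetical lengths L_main = L_extra = 100, s = 1 and B = 32, the count is ceil(100 / 32) + ceil(100 / 32) = 4 + 4 = 8. A single 200-token segment would need only ceil(200 / 32) = 7, which is why handling the two segments separately matters for the estimate.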

Source code in vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
def _decode_partial_iters(
    avg_main_len: float, avg_extra_len: float, splits: int, block_k: int
) -> int:
    """BLOCK_K iterations one partial workgroup walks for ``splits`` splits.

    Each split processes ``ceil(seg_len / splits)`` tokens of a segment, walked
    ``BLOCK_K`` at a time, and the main/extra segments are handled separately.
    """
    main_iters = (
        math.ceil(math.ceil(avg_main_len / splits) / block_k) if avg_main_len > 0 else 0
    )
    extra_iters = (
        math.ceil(math.ceil(avg_extra_len / splits) / block_k)
        if avg_extra_len > 0
        else 0
    )
    return main_iters + extra_iters

_fused_inverse_rope_gptj(o, positions, cos_sin_cache, rope_head_dim)

bf16 inverse GPT-J RoPE via a single fused Triton kernel.

Source code in vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
def _fused_inverse_rope_gptj(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    rope_head_dim: int,
) -> torch.Tensor:
    """bf16 inverse GPT-J RoPE via a single fused Triton kernel."""
    assert o.dim() == 3 and o.stride(-1) == 1, (
        "_fused_inverse_rope_gptj expects a [T, H, D] input with a contiguous last dim"
    )
    assert rope_head_dim > 0 and rope_head_dim % 2 == 0, (
        f"_fused_inverse_rope_gptj expects an even rope_head_dim, got {rope_head_dim}"
    )
    assert cos_sin_cache.shape[-1] == rope_head_dim, (
        "_fused_inverse_rope_gptj expects cos_sin_cache laid out as "
        f"[P, {rope_head_dim}] = cos | sin, got {tuple(cos_sin_cache.shape)}"
    )
    num_tokens, num_heads, head_dim = o.shape
    out = torch.empty(
        (num_tokens, num_heads, head_dim), dtype=torch.bfloat16, device=o.device
    )
    if num_tokens == 0:
        return out
    _inverse_rope_gptj_kernel[(num_tokens, num_heads)](
        o,
        out,
        positions,
        cos_sin_cache,
        o.stride(0),
        o.stride(1),
        out.stride(0),
        out.stride(1),
        cos_sin_cache.stride(0),
        NOPE=head_dim - rope_head_dim,
        HALF=rope_head_dim // 2,
        BLOCK_NOPE=triton.next_power_of_2(head_dim - rope_head_dim),
        BLOCK_HALF=triton.next_power_of_2(rope_head_dim // 2),
    )
    return out

_get_cached_wo_a_bf16(wo_a, n_local_groups, o_lora_rank, hidden_dim)

Dequantize wo_a to bf16 once and cache it on the module.

wo_a weights are static, so the fp8 -> fp32 -> (* block scale) -> bf16 dequant only needs to run once. Recomputing it every decode step shows up in the profile as the largest copy/mul kernels (direct_copy float ~55us and MulFunctor float ~31us per two layers). SGLang / ATOM keep wo_a in bf16 and feed a plain bf16 GEMM; this mirrors that.
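The dequantize-once pattern can be shown without any tensor library. This is a simplified sketch under stated assumptions: FakeQuantLayer, get_cached_dequant and the attribute name _cached_dequant are hypothetical, and a single scalar scale stands in for the block scales used by the real function.

```python
class FakeQuantLayer:
    """Hypothetical stand-in for a module holding quantized weights and a scale."""

    def __init__(self, weight: list[float], scale: float) -> None:
        self.weight = weight
        self.scale = scale


dequant_calls = 0  # counts how often the expensive path actually runs


def get_cached_dequant(layer: FakeQuantLayer) -> list[float]:
    """Dequantize on first use, then reuse the result stored on the layer."""
    global dequant_calls
    cached = getattr(layer, "_cached_dequant", None)
    if cached is not None:
        return cached
    dequant_calls += 1
    cached = [w * layer.scale for w in layer.weight]
    layer._cached_dequant = cached  # cache lives as long as the layer does
    return cached


layer = FakeQuantLayer([1.0, -2.0, 4.0], scale=0.5)
first = get_cached_dequant(layer)
second = get_cached_dequant(layer)
print(first, first is second, dequant_calls)
```

The second call returns the same object without redoing the multiply. A cache of this kind is only valid while the weights stay unchanged; anything that reloads them would have to clear the cached attribute.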

Source code in vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
def _get_cached_wo_a_bf16(
    wo_a: torch.nn.Module,
    n_local_groups: int,
    o_lora_rank: int,
    hidden_dim: int,
) -> torch.Tensor:
    """Dequantize wo_a to bf16 once and cache it on the module.

    wo_a weights are static, so the fp8 -> fp32 -> (* block scale) -> bf16
    dequant only needs to run once. Recomputing it every decode step shows up
    in the profile as the largest copy/mul kernels (``direct_copy float`` ~55us
    and ``MulFunctor float`` ~31us per two layers). SGLang / ATOM keep wo_a in
    bf16 and feed a plain bf16 GEMM; this mirrors that.
    """
    cached = getattr(wo_a, "_dsv4_wo_a_bf16", None)
    if cached is not None:
        return cached
    if hasattr(wo_a, "weight_scale_inv"):
        wo_a_weight = wo_a.weight.view(n_local_groups, o_lora_rank, hidden_dim).to(
            torch.float32
        )
        wo_a_scale = _expand_2d_block_scales(
            wo_a.weight_scale_inv.view(
                n_local_groups, -1, wo_a.weight_scale_inv.shape[-1]
            ),
            o_lora_rank,
            hidden_dim,
        )
        cached = (wo_a_weight * wo_a_scale).to(torch.bfloat16)
    else:
        cached = wo_a.weight.view(n_local_groups, o_lora_rank, hidden_dim).to(
            torch.bfloat16
        )
    wo_a._dsv4_wo_a_bf16 = cached
    return cached

_inverse_rope_gptj_kernel(o_ptr, out_ptr, pos_ptr, cos_sin_ptr, s_t, s_h, os_t, os_h, cs_stride, NOPE, HALF, BLOCK_NOPE, BLOCK_HALF)

Fused inverse GPT-J RoPE on the trailing rope_dim of each (token, head).

Mirrors DeepseekV4ScalingRotaryEmbedding.forward_native(inverse=True) for the GPT-J (non-neox) layout, writing bf16 directly. Replaces the clone + index_select + repeat_interleave + neg + stack + cat + cast chain (~10 small kernels) with a single launch.
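The per-pair arithmetic can be sanity-checked in plain Python. The sketch below is illustrative only: rope_gptj_head is a hypothetical helper, the forward rotation convention (even' = a*cos - b*sin, odd' = b*cos + a*sin) is an assumption, angles are computed directly instead of being read from a cos/sin cache, and the bf16 cast is omitted.

```python
import math


def rope_gptj_head(head, theta, rope_dim, inverse=False):
    """Rotate the trailing `rope_dim` lanes of one head; leading lanes pass through.

    `head` is a list of D floats and `theta` holds rope_dim // 2 angles,
    one per interleaved (even, odd) pair.
    """
    nope = len(head) - rope_dim
    out = list(head)  # NoPE lanes are copied unchanged
    for k in range(rope_dim // 2):
        a, b = head[nope + 2 * k], head[nope + 2 * k + 1]
        cos, sin = math.cos(theta[k]), math.sin(theta[k])
        if inverse:
            # Same formulas as the kernel: sin negated relative to the forward pass.
            out[nope + 2 * k] = a * cos + b * sin
            out[nope + 2 * k + 1] = b * cos - a * sin
        else:
            out[nope + 2 * k] = a * cos - b * sin
            out[nope + 2 * k + 1] = b * cos + a * sin
    return out


head = [9.0, 8.0, 1.0, 2.0, 3.0, 4.0]  # 2 NoPE lanes followed by rope_dim = 4
theta = [0.3, 1.1]
rotated = rope_gptj_head(head, theta, rope_dim=4)
restored = rope_gptj_head(rotated, theta, rope_dim=4, inverse=True)
print(max(abs(x - y) for x, y in zip(head, restored)))
```

Applying the forward rotation and then the inverse recovers the original head up to floating-point rounding, and the two leading NoPE lanes are never modified.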

Source code in vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
@triton.jit
def _inverse_rope_gptj_kernel(
    o_ptr,  # [T, H, D] input
    out_ptr,  # [T, H, D] bf16 output
    pos_ptr,  # [T] positions
    cos_sin_ptr,  # [P, rope_dim] fp32 (cos[:half] | sin[half:])
    s_t,
    s_h,  # input row strides (last dim contiguous)
    os_t,
    os_h,  # output row strides
    cs_stride,  # cos_sin_cache row stride
    NOPE: tl.constexpr,  # non-rope head dims (passed through)
    HALF: tl.constexpr,  # rope_dim // 2
    BLOCK_NOPE: tl.constexpr,
    BLOCK_HALF: tl.constexpr,
):
    """Fused inverse GPT-J RoPE on the trailing rope_dim of each (token, head).

    Mirrors ``DeepseekV4ScalingRotaryEmbedding.forward_native(inverse=True)``
    for the GPT-J (non-neox) layout, writing bf16 directly. Replaces the
    clone + index_select + repeat_interleave + neg + stack + cat + cast chain
    (~10 small kernels) with a single launch.
    """
    t = tl.program_id(0)
    h = tl.program_id(1)
    in_base = t * s_t + h * s_h
    out_base = t * os_t + h * os_h

    # NoPE lanes pass through unchanged (only cast to bf16).
    n = tl.arange(0, BLOCK_NOPE)
    nmask = n < NOPE
    vals = tl.load(o_ptr + in_base + n, mask=nmask)
    tl.store(out_ptr + out_base + n, vals.to(tl.bfloat16), mask=nmask)

    # RoPE lanes: out_even = a*cos + b*sin, out_odd = b*cos - a*sin
    # (a = even lane, b = odd lane; sin negated for the inverse rotation).
    pos = tl.load(pos_ptr + t).to(tl.int64)
    k = tl.arange(0, BLOCK_HALF)
    kmask = k < HALF
    a = tl.load(o_ptr + in_base + NOPE + 2 * k, mask=kmask).to(tl.float32)
    b = tl.load(o_ptr + in_base + NOPE + 2 * k + 1, mask=kmask).to(tl.float32)
    cos = tl.load(cos_sin_ptr + pos * cs_stride + k, mask=kmask)
    sin = tl.load(cos_sin_ptr + pos * cs_stride + HALF + k, mask=kmask)
    out_even = a * cos + b * sin
    out_odd = b * cos - a * sin
    tl.store(out_ptr + out_base + NOPE + 2 * k, out_even.to(tl.bfloat16), mask=kmask)
    tl.store(out_ptr + out_base + NOPE + 2 * k + 1, out_odd.to(tl.bfloat16), mask=kmask)

fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)

Compute FP8 MQA logits for a single sequence without KV paging.

Parameters:

  • q

    (Tensor) –

    Query tensor of shape [M, H, D]. Cast to torch.float8_e4m3fn by caller.

  • kv

    (tuple[Tensor, Tensor]) –

    Tuple (k_fp8, k_scales) where k_fp8 has shape [N, D] with dtype torch.float8_e4m3fn and k_scales has shape [N] (or [N, 1]) with dtype torch.float32.

  • weights

    (Tensor) –

    weights of shape [M, H], dtype torch.float32.

  • cu_seqlen_ks

    (Tensor) –

    Start indices (inclusive) for valid K per query position, shape [M], dtype int32.

  • cu_seqlen_ke

    (Tensor) –

    End indices (exclusive) for valid K per query position, shape [M], dtype int32.

Returns:

  • Tensor

    Logits tensor of shape [M, N], dtype torch.float32.
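The computation can be spelled out element by element: for each query m and key n inside the valid window, the per-head dot product is scaled by the key's scale, passed through ReLU, weighted per head, and summed over heads. The following pure-Python sketch uses a hypothetical helper name, mqa_logits_reference, and ordinary floats, so it ignores the FP8 and bf16 casts of the real function.

```python
def mqa_logits_reference(q, k, scale, weights, ks, ke):
    """q: [M][H][D], k: [N][D], scale: [N], weights: [M][H], ks/ke: [M]."""
    neg_inf = float("-inf")
    logits = []
    for m, q_heads in enumerate(q):
        row = []
        for n, k_vec in enumerate(k):
            if not (ks[m] <= n < ke[m]):
                row.append(neg_inf)  # key outside this query's valid window
                continue
            total = 0.0
            for h, q_vec in enumerate(q_heads):
                score = sum(x * y for x, y in zip(q_vec, k_vec)) * scale[n]
                total += max(score, 0.0) * weights[m][h]  # ReLU, then head weight
            row.append(total)
        logits.append(row)
    return logits


q = [[[1.0, 0.0]], [[0.0, 1.0]]]          # M=2, H=1, D=2
k = [[1.0, 0.0], [0.0, -1.0], [2.0, 2.0]]  # N=3
logits = mqa_logits_reference(
    q, k, scale=[1.0, 1.0, 0.5], weights=[[2.0], [1.0]], ks=[0, 1], ke=[2, 3]
)
print(logits)
```

Query 0 sees keys 0 and 1 only, and query 1 sees keys 1 and 2 only, so the masked positions come out as -inf while the remaining entries are the weighted ReLU scores.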

Source code in vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
def fp8_mqa_logits_torch(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    k_fp8, scale = kv
    seq_len_kv = k_fp8.shape[0]
    k = k_fp8.to(torch.bfloat16)
    q = q.to(torch.bfloat16)
    device = q.device

    mask_lo = (
        torch.arange(0, seq_len_kv, device=device)[None, :] >= cu_seqlen_ks[:, None]
    )
    mask_hi = (
        torch.arange(0, seq_len_kv, device=device)[None, :] < cu_seqlen_ke[:, None]
    )
    mask = mask_lo & mask_hi

    score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits

rocm_fp8_mqa_logits(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)

Compute FP8 MQA logits for a single sequence without KV paging.

Parameters:

  • q

    (Tensor) –

    Query tensor of shape [M, H, D]. Cast to torch.float8_e4m3fn by caller.

  • kv

    (tuple[Tensor, Tensor]) –

    Tuple (k_fp8, k_scales) where k_fp8 has shape [N, D] with dtype torch.float8_e4m3fn and k_scales has shape [N] (or [N, 1]) with dtype torch.float32.

  • weights

    (Tensor) –

    weights of shape [M, H], dtype torch.float32.

  • cu_seqlen_ks

    (Tensor) –

    Start indices (inclusive) for valid K per query position, shape [M], dtype int32.

  • cu_seqlen_ke

    (Tensor) –

    End indices (exclusive) for valid K per query position, shape [M], dtype int32.

Returns:

  • Tensor

    Logits tensor of shape [M, N], dtype torch.float32.

Source code in vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
def rocm_fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """

    # TODO(ganyi): Temporary workaround; remove the module check and reference
    # path once aiter merges this kernel into main.
    from vllm._aiter_ops import rocm_aiter_ops

    aiter_mqa_logits_module = None
    if rocm_aiter_ops.is_enabled():
        aiter_mqa_logits_module = mqa_logits_module()

    if aiter_mqa_logits_module is not None:
        fp8_mqa_logits = aiter_mqa_logits_module.fp8_mqa_logits
        k_fp8, scale = kv
        return fp8_mqa_logits(q, k_fp8, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
    else:
        return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)

rocm_fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len)

Compute FP8 MQA logits using paged KV-cache.

Parameters:

  • q_fp8

    (Tensor) –

    Query tensor of shape [B, next_n, H, D]. Cast to torch.float8_e4m3fn by caller.

  • kv_cache_fp8

    (Tensor) –

    Paged KV-cache in packed FP8+scale layout with shape [num_blocks, block_size, 1, D+4], dtype torch.uint8. The last 4 bytes per (block,pos) store the float dequant scale.

  • weights

    (Tensor) –

    Tensor of shape [B * next_n, H], dtype torch.float32.

  • context_lens

    (Tensor) –

    Tensor of shape [B], dtype int32; effective context length for each batch element.

  • block_tables

    (Tensor) –

    Tensor of shape [B, max_blocks], dtype int32; maps logical block indices to physical blocks in the paged cache.

  • schedule_metadata

    (Tensor) –

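    Returned by get_paged_mqa_logits_metadata; used to distribute work across SMs (compute units on ROCm). The source listing for this function accepts the argument but does not reference it.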

  • max_model_len

    (int) –

    Maximum sequence length used to size the logits output.

Returns:

  • Tensor

    Logits tensor of shape [B * next_n, max_model_len], dtype torch.float32.
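Two pieces of the paged layout described in the parameters can be sketched with the standard library: mapping a logical token index to a physical block through the block table, and splitting one D+4 byte entry into its FP8 payload and float scale. Both helper names are hypothetical, and little-endian byte order for the scale is an assumption; the actual kernels may also reorder data (see the Preshuffle flag in the source).

```python
import struct


def locate_token(block_table_row, token_idx, block_size):
    """Map a logical token index to (physical_block, offset_in_block)."""
    return block_table_row[token_idx // block_size], token_idx % block_size


def split_packed_entry(entry: bytes, head_dim: int):
    """Split one (block, pos) entry of D + 4 bytes into FP8 bytes and a float32 scale."""
    assert len(entry) == head_dim + 4, "expected D FP8 bytes plus a 4-byte scale"
    fp8_bytes = entry[:head_dim]
    (scale,) = struct.unpack("<f", entry[head_dim:])  # assumed little-endian
    return fp8_bytes, scale


# Logical token 5 with block_size 4 lives in logical block 1, offset 1.
print(locate_token([7, 3, 9], token_idx=5, block_size=4))

entry = bytes([1, 2, 3, 4]) + struct.pack("<f", 0.5)
print(split_packed_entry(entry, head_dim=4))
```

The FP8 payload is left as raw bytes here because the standard library has no FP8 type; decoding it to real values would need a tensor library.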

Source code in vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
def rocm_fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute FP8 MQA logits using paged KV-cache.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D]. Cast to
            `torch.float8_e4m3fn` by caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block,pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context length
            for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
            block indices to physical blocks in the paged cache.
        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
            used to distribute work across SMs.
        max_model_len: Maximum sequence length used to size the logits output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """
    from vllm._aiter_ops import rocm_aiter_ops

    aiter_paged_mqa_logits_module = None
    batch_size, next_n = q_fp8.shape[:2]
    block_size = kv_cache_fp8.shape[1]

    if rocm_aiter_ops.is_enabled():
        aiter_paged_mqa_logits_module = paged_mqa_logits_module()

    if aiter_paged_mqa_logits_module is not None:
        if _ON_GFX942 or _ON_GFX950:
            deepgemm_fp8_paged_mqa_logits = (
                aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits
            )
            batch_size, next_n, heads, _ = q_fp8.shape
            (out_logits,) = current_workspace_manager().get_simultaneous(
                ((batch_size * next_n, max_model_len), torch.float32),
            )
            out_logits.fill_(float("-inf"))
            deepgemm_fp8_paged_mqa_logits(
                q_fp8,
                kv_cache_fp8,
                weights,
                out_logits,
                context_lens,
                block_tables,
                max_model_len,
                ChunkK=256,
                Preshuffle=block_size > 1,
                KVBlockSize=block_size,
                WavePerEU=2,
            )
            return out_logits
        deepgemm_fp8_paged_mqa_logits_stage1 = (
            aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1
        )
        batch_size, next_n, heads, _ = q_fp8.shape
        (out_qk,) = current_workspace_manager().get_simultaneous(
            ((heads, batch_size * next_n, max_model_len), torch.float32),
        )
        out_qk.fill_(float("-inf"))
        deepgemm_fp8_paged_mqa_logits_stage1(
            q_fp8,
            kv_cache_fp8,
            weights,
            out_qk,
            context_lens,
            block_tables,
            max_model_len,
            ChunkQ=heads,
        )
        return out_qk.sum(dim=0)
    else:
        return fp8_paged_mqa_logits_torch(
            q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len
        )

rocm_inv_rope_einsum(rotary_emb, o, positions, rope_head_dim, n_local_groups, o_lora_rank, wo_a)

Inverse-RoPE + WO_A bmm path used on ROCm.

Fuses the inverse GPT-J RoPE into one Triton kernel and caches the bf16 wo_a weight so the per-step dequant disappears.
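The final contraction in this path is the einsum "tgd,grd->tgr": each token's activations are split into groups, and every group is multiplied by its own [R, D] weight matrix. A loop-based sketch of that contraction follows; grouped_projection is a hypothetical name and the inputs are small made-up values.

```python
def grouped_projection(o, w):
    """o: [T][G][D], w: [G][R][D] -> out: [T][G][R] (einsum "tgd,grd->tgr")."""
    return [
        [
            # For token t and group g, take the dot product with each weight row r.
            [sum(x * y for x, y in zip(o_tg, w_gr)) for w_gr in w[g]]
            for g, o_tg in enumerate(o_t)
        ]
        for o_t in o
    ]


o = [[[1.0, 2.0], [3.0, 4.0]]]  # T=1, G=2, D=2
w = [
    [[1.0, 0.0], [0.0, 1.0]],   # group 0: identity
    [[1.0, 1.0], [2.0, 0.0]],   # group 1
]
print(grouped_projection(o, w))
```

Group 0 uses an identity matrix, so its output repeats the input [1.0, 2.0]; group 1 yields [3 + 4, 2 * 3] = [7.0, 6.0]. Weights are never shared across groups, which is why the operation behaves like a batched matmul over g.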

Source code in vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
def rocm_inv_rope_einsum(
    rotary_emb: torch.nn.Module,
    o: torch.Tensor,
    positions: torch.Tensor,
    rope_head_dim: int,
    n_local_groups: int,
    o_lora_rank: int,
    wo_a: torch.nn.Module,
) -> torch.Tensor:
    """Inverse-RoPE + WO_A bmm path used on ROCm.

    Fuses the inverse GPT-J RoPE into one Triton kernel and caches the bf16
    wo_a weight so the per-step dequant disappears.
    """
    o_ref = _fused_inverse_rope_gptj(
        o, positions, rotary_emb.cos_sin_cache, rope_head_dim
    )
    o_ref = o_ref.view(o.shape[0], n_local_groups, -1)

    wo_a_weight = _get_cached_wo_a_bf16(
        wo_a, n_local_groups, o_lora_rank, o_ref.shape[-1]
    )

    return torch.einsum("tgd,grd->tgr", o_ref, wo_a_weight)