
vllm.v1.attention.ops.triton_fp8_mqa_logits

Temporary gfx942 fallback for AITER's fp8_mqa_logits kernel.

This module vendors AITER's Triton fp8_mqa_logits kernel with the gfx942 tile-size workaround from ROCm/aiter#3257. It is used only while vLLM's pinned AITER version lacks that fix.

TODO: Remove this vendored copy once vLLM pins an AITER version that includes the ROCm/aiter#3257 bug fix for gfx942.

Functions:

_gfx942_default_tile_fits_lds(num_heads, head_size)

Return True iff (BLOCK_KV=128, num_stages=2) fits in MI300X LDS.

Source code in vllm/v1/attention/ops/triton_fp8_mqa_logits.py
def _gfx942_default_tile_fits_lds(num_heads: int, head_size: int) -> bool:
    """Return True iff (BLOCK_KV=128, num_stages=2) fits in MI300X LDS."""
    BLOCK_KV = 128
    NUM_STAGES = 2
    kv_bytes = head_size * BLOCK_KV * NUM_STAGES
    scores_bytes = num_heads * BLOCK_KV * 4
    q_bytes = num_heads * head_size
    fits_occupancy = kv_bytes < _GFX942_PER_WG_LDS_BUDGET_BYTES
    fits_hardware = q_bytes + kv_bytes + scores_bytes <= _GFX942_CU_LDS_BYTES
    return fits_occupancy and fits_hardware
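Here is a worked example of the arithmetic above. The sketch recomputes the three LDS terms (FP8 query tile, per-stage FP8 K tile, float32 score tile) for the DSv4 sparse indexer shape (num_heads=64, head_size=128) that the source comment names. It checks only the hardware limit. The per-workgroup budget _GFX942_PER_WG_LDS_BUDGET_BYTES is not shown on this page, so it is left out. lds_bytes is a hypothetical helper, and the 64 KiB CU figure comes from the fp8_mqa_logits_gfx942 docstring.

```python
# Hypothetical helper that mirrors the byte accounting in
# _gfx942_default_tile_fits_lds; it is not part of vLLM.
GFX942_CU_LDS_BYTES = 64 * 1024  # 64 KiB of LDS per gfx942 CU (from the docstring)


def lds_bytes(num_heads: int, head_size: int, block_kv: int, num_stages: int) -> dict:
    """Return the LDS terms, in bytes, that the fit check adds up."""
    q_bytes = num_heads * head_size               # FP8 query tile, 1 byte/element
    kv_bytes = head_size * block_kv * num_stages  # FP8 K tile, one copy per stage
    scores_bytes = num_heads * block_kv * 4       # float32 score tile, 4 bytes/element
    return {
        "q": q_bytes,
        "kv": kv_bytes,
        "scores": scores_bytes,
        "total": q_bytes + kv_bytes + scores_bytes,
    }


if __name__ == "__main__":
    # DSv4 sparse indexer shape named in the source comment.
    for block_kv, num_stages in [(128, 2), (64, 1)]:
        terms = lds_bytes(64, 128, block_kv, num_stages)
        fits = terms["total"] <= GFX942_CU_LDS_BYTES
        print(block_kv, num_stages, terms["total"], fits)
```

For this shape the default (128, 2) tile totals 73,728 bytes, which is above the 65,536-byte limit. The (64, 1) tile totals 32,768 bytes.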

fp8_mqa_logits_gfx942(q, k_fp8, kv_scales, weights, cu_starts, cu_ends)

Compute FP8 MQA logits on MI300X (gfx942) using the vendored kernel.

Drop-in replacement for aiter.ops.triton.attention.fp8_mqa_logits.fp8_mqa_logits on MI300X. Selects (BLOCK_KV, num_stages) based on whether the default tile fits within the 64 KiB LDS budget of a gfx942 CU (see module docstring).
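The launch-parameter choice is small enough to state as a pure function. The sketch below restates two rules from the source listing: the (BLOCK_KV, num_stages) selection and the matrix_instr_nonkdim heuristic. It takes the result of the LDS fit check as a boolean input. choose_launch_config and LaunchConfig are hypothetical names used for illustration; they are not vLLM APIs.

```python
from typing import NamedTuple


class LaunchConfig(NamedTuple):
    """Hypothetical container for the per-call launch parameters."""
    block_kv: int
    num_stages: int
    matrix_instr_nonkdim: int


def choose_launch_config(seq_len: int, default_tile_fits: bool) -> LaunchConfig:
    # Keep the (128, 2) default when it fits in LDS; otherwise use the
    # smaller, single-stage (64, 1) tile, as in the source listing.
    if default_tile_fits:
        block_kv, num_stages = 128, 2
    else:
        block_kv, num_stages = 64, 1
    # MFMA shape heuristic copied from the listing: 16 for seq_len <= 1024,
    # 32 otherwise.
    matrix_instr_nonkdim = 16 if seq_len <= 1024 else 32
    return LaunchConfig(block_kv, num_stages, matrix_instr_nonkdim)
```

The tile choice and the MFMA choice are independent. One depends only on (num_heads, head_size) through the fit check, and the other depends only on the number of query rows.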

Parameters:

  • q

    (Tensor) –

    Query tensor of shape [M, H, D], FP8 dtype.

  • k_fp8

    (Tensor) –

    Key tensor of shape [N, D], FP8 dtype.

  • kv_scales

    (Tensor) –

    K scales of shape [N] (or [N, 1] -- viewed as [N]), float32.

  • weights

    (Tensor) –

    Per-head weights of shape [M, H], float32.

  • cu_starts

    (Tensor) –

    Start indices (inclusive) of shape [M], int32.

  • cu_ends

    (Tensor) –

    End indices (exclusive) of shape [M], int32.

Returns:

  • Tensor

    Logits of shape [M, N], float32. Positions outside [cu_starts[i], cu_ends[i]) for row i are pre-filled with -inf so the caller can run a top-k without masking.
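This page documents the output layout and the -inf pre-fill but not the kernel's scoring math. The pure-Python sketch below therefore assumes a DeepSeek-style indexer score and works on already-dequantized floats. The assumed score is the per-head ReLU of q·k, weighted by weights and scaled by kv_scales. Treat that formula as an assumption and check it against the vendored kernel before relying on it. Two parts of the sketch do reflect the docstring: the half-open [cu_starts[i], cu_ends[i]) window and the -inf pre-fill. Together they let a plain descending sort act as top-k without a mask. reference_mqa_logits and topk_indices are hypothetical helpers.

```python
import math


def reference_mqa_logits(q, k, kv_scales, weights, cu_starts, cu_ends):
    """Reference sketch on nested lists: q [M][H][D], k [N][D], weights [M][H].

    ASSUMPTION (not documented on this page):
        logits[i][j] = kv_scales[j] * sum_h weights[i][h] * relu(q[i][h] . k[j])
    """
    m, n = len(q), len(k)
    logits = [[-math.inf] * n for _ in range(m)]  # pre-fill, as the wrapper does
    for i in range(m):
        for j in range(cu_starts[i], cu_ends[i]):  # half-open [start, end)
            score = 0.0
            for h, q_ih in enumerate(q[i]):
                dot = sum(a * b for a, b in zip(q_ih, k[j]))
                score += weights[i][h] * max(dot, 0.0)  # ReLU per head
            logits[i][j] = kv_scales[j] * score
    return logits


def topk_indices(row, k):
    # -inf entries sort last, so out-of-window positions never outrank
    # in-window ones and no explicit mask is needed.
    return sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
```

If a row has fewer than k in-window positions, the tail of the top-k result points at -inf entries, so a consumer may still need to check for them.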

Source code in vllm/v1/attention/ops/triton_fp8_mqa_logits.py
def fp8_mqa_logits_gfx942(
    q: torch.Tensor,
    k_fp8: torch.Tensor,
    kv_scales: torch.Tensor,
    weights: torch.Tensor,
    cu_starts: torch.Tensor,
    cu_ends: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits on MI300X (gfx942) using the vendored kernel.

    Drop-in replacement for
    ``aiter.ops.triton.attention.fp8_mqa_logits.fp8_mqa_logits`` on MI300X.
    Selects ``(BLOCK_KV, num_stages)`` based on whether the default tile
    fits within the 64 KiB LDS budget of a gfx942 CU (see module docstring).

    Args:
        q: Query tensor of shape ``[M, H, D]``, FP8 dtype.
        k_fp8: Key tensor of shape ``[N, D]``, FP8 dtype.
        kv_scales: K scales of shape ``[N]`` (or ``[N, 1]`` -- viewed as
            ``[N]``), float32.
        weights: Per-head weights of shape ``[M, H]``, float32.
        cu_starts: Start indices (inclusive) of shape ``[M]``, int32.
        cu_ends: End indices (exclusive) of shape ``[M]``, int32.

    Returns:
        Logits of shape ``[M, N]``, float32 -- positions outside
        ``[cu_starts[i], cu_ends[i])`` for row ``i`` are pre-filled with
        ``-inf`` so the caller can run a top-k without masking.
    """
    seq_len, num_heads, head_size = q.shape
    seq_len_kv = k_fp8.shape[0]
    assert num_heads & (num_heads - 1) == 0, (
        f"num_heads must be a power of two (got {num_heads})"
    )
    assert head_size & (head_size - 1) == 0, (
        f"head_size must be a power of two (got {head_size})"
    )

    # The kernel walks ``kv_scales`` as a 1-D contiguous array of size N
    # (it indexes by ``kv_scales_ptr + kv_col_offsets``). The vLLM caller
    # passes a ``[N, 4]`` uint8 view-cast-to-float32 which lands as
    # ``[N, 1]`` contiguous -- byte-identical to ``[N]`` -- but flatten
    # explicitly to keep the kernel's pointer arithmetic intent clear.
    kv_scales_1d = kv_scales.reshape(-1)

    # Initialise with -inf so positions outside [cu_starts, cu_ends) read
    # as ``-inf`` after the masked store path -- this matches AITER's
    # ``fp8_mqa_logits`` semantics and is what the top-k consumer expects.
    logits = torch.full(
        (seq_len, seq_len_kv),
        fill_value=-float("inf"),
        dtype=torch.float32,
        device=q.device,
    )

    if _gfx942_default_tile_fits_lds(num_heads, head_size):
        block_kv = 128
        num_stages = 2
    else:
        # DSv4 sparse indexer (NUM_HEADS=64, HEAD_SIZE=128) lands here:
        # default tile spills past gfx942's 64 KiB LDS budget. (64, 1)
        # needs ~33 KiB and clears the per-WG budget with margin.
        block_kv = 64
        num_stages = 1

    # heuristic for MFMA instruction shape, identical to AITER's choice
    matrix_instr_nonkdim = 32
    if seq_len <= 1024:
        matrix_instr_nonkdim = 16

    stride_q_s, stride_q_h, stride_q_d = q.stride()
    stride_kv_s, stride_kv_d = k_fp8.stride()
    stride_w_s, stride_w_h = weights.stride()
    stride_logits_s, stride_logits_k = logits.stride()

    _fp8_mqa_logits_kernel[(seq_len,)](
        Q_ptr=q,
        KV_ptr=k_fp8,
        kv_scales_ptr=kv_scales_1d,
        weights_ptr=weights,
        cu_start_ptr=cu_starts,
        cu_end_ptr=cu_ends,
        logits_ptr=logits,
        seq_len=seq_len,
        seq_len_kv=seq_len_kv,
        NUM_HEADS=num_heads,
        HEAD_SIZE=head_size,
        stride_q_s=stride_q_s,
        stride_q_h=stride_q_h,
        stride_q_d=stride_q_d,
        stride_kv_s=stride_kv_s,
        stride_kv_d=stride_kv_d,
        stride_w_s=stride_w_s,
        stride_w_h=stride_w_h,
        stride_logits_s=stride_logits_s,
        stride_logits_k=stride_logits_k,
        BLOCK_KV=block_kv,
        num_warps=4,
        num_stages=num_stages,
        waves_per_eu=2,
        matrix_instr_nonkdim=matrix_instr_nonkdim,
    )

    return logits
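The kv_scales comment in the listing relies on a byte-level argument. Four uint8 bytes per row reinterpret as one float32, so an [N, 4] uint8 buffer viewed as float32 has shape [N, 1]. A contiguous [N, 1] buffer has the same memory layout as [N]. The stdlib-only sketch below illustrates that reinterpretation without PyTorch. scales_from_uint8_rows is a hypothetical helper, and it assumes the bytes were written in the machine's native byte order.

```python
import struct
from array import array


def scales_from_uint8_rows(rows: list[bytes]) -> array:
    """Reinterpret [N, 4] uint8 rows as N float32 values (a flat [N] view)."""
    if any(len(r) != 4 for r in rows):
        raise ValueError("each row must hold exactly 4 bytes (one float32)")
    out = array("f")
    # Concatenating the rows is the contiguous [N, 4] buffer; reading it as
    # float32 yields one value per row, i.e. N values in [N] order.
    out.frombytes(b"".join(rows))
    return out


if __name__ == "__main__":
    # Hypothetical scales packed as native-order float32 bytes, one row per token.
    rows = [struct.pack("=f", s) for s in [0.5, 1.25, 3.0]]
    print(list(scales_from_uint8_rows(rows)))  # [0.5, 1.25, 3.0]
```

No bytes move in this reinterpretation, which is why flattening with reshape(-1) on the contiguous [N, 1] tensor only changes the shape metadata the kernel sees.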