Skip to content

vllm.models.minimax_m3.amd.ops.sparse_attn

ROCm gfx942/gfx950 block-sparse GQA prefill kernel for MiniMax-M3.

Only the prefill path is specialized on CDNA: each 128-token KV block is split into SUB_K-token sub-tiles to right-size the per-block QK/PV MFMAs. Everything else -- the decode split-K kernels, the FP8 dtype set, the sparse block size -- is reused unchanged from common.ops.sparse_attn.

Functions:

_sparse_attn_prefill_kwargs()

MFMA + pipeline launch params for the sub-tiled prefill kernel.

gfx942 and gfx950 share the same params: num_warps=1 keeps one wave resident on the small per-sub-tile GEMM, matrix_instr_nonkdim=16 / kpack=2 select the MFMA_16x16 path, and num_stages=1 fits LDS and is fastest in the sweep. Only the sub-tile width (_SPARSE_ATTN_SUB_K) differs by arch. On other AMD archs the returned dict is empty. The result is cached, since the arch is fixed per process.

Source code in vllm/models/minimax_m3/amd/ops/sparse_attn.py
def _sparse_attn_prefill_kwargs() -> dict:
    """Return the launch kwargs for the sub-tiled prefill kernel (memoized).

    When ``on_mi3xx()`` is true (the gfx942 / gfx950 case) the dict carries the
    MFMA + pipeline settings, which are the same for both archs:

    * ``num_warps=1`` keeps one wave resident on the small per-sub-tile GEMM.
    * ``matrix_instr_nonkdim=16`` with ``kpack=2`` selects the MFMA_16x16 path.
    * ``num_stages=1`` fits LDS and is fastest in the sweep.

    The sub-tile width (``_SPARSE_ATTN_SUB_K``) is the only per-arch setting
    and is not part of this dict. On other AMD archs the dict is empty.

    The result is stored in the module-level ``_SPARSE_ATTN_PREFILL_KWARG``
    because the arch is fixed per process. Every call returns that same dict
    object, so callers should treat it as read-only.
    """
    global _SPARSE_ATTN_PREFILL_KWARG

    # Fast path: already resolved earlier in this process.
    cached = _SPARSE_ATTN_PREFILL_KWARG
    if cached is not None:
        return cached

    if on_mi3xx():
        cached = {
            "num_warps": 1,
            "matrix_instr_nonkdim": 16,
            "kpack": 2,
            "num_stages": 1,
        }
    else:
        # No tuned launch params for this arch.
        cached = {}

    _SPARSE_ATTN_PREFILL_KWARG = cached
    return cached

minimax_m3_sparse_attn(q, kv_cache, topk_idx, block_table, cu_seqlens_q, seq_lens, prefix_lens, max_query_len, num_kv_heads, sm_scale, output)

GQA block-sparse attention over the selected (top-k) KV blocks, launched with a query block size of 1 (block_size_q == 1).

Source code in vllm/models/minimax_m3/amd/ops/sparse_attn.py
@torch.no_grad()
def minimax_m3_sparse_attn(
    q: torch.Tensor,  # [total_q, num_heads, head_dim]
    kv_cache: torch.Tensor,  # [num_blocks, 2, 128, num_kv_heads, head_dim]
    topk_idx: torch.Tensor,  # [num_kv_heads, total_q, topk]
    block_table: torch.Tensor,  # [batch, max_blocks]
    cu_seqlens_q: torch.Tensor,  # [batch+1] int32
    seq_lens: torch.Tensor,  # [batch] int32
    prefix_lens: torch.Tensor,  # [batch] int32
    max_query_len: int,
    num_kv_heads: int,
    sm_scale: float,
    output: torch.Tensor,  # [total_q, num_heads, head_dim]
) -> None:
    """Launch the GQA block-sparse attention kernel with block_size_q == 1.

    Attention is computed over the KV blocks selected by ``topk_idx``. Tensor
    layouts are the ones noted next to each parameter in the signature.

    Args:
        q: Query tensor; must be 3-D (its shape is unpacked into three dims).
        kv_cache: KV cache. An FP8 dtype (member of ``_FP8_DTYPES``) turns on
            the kernel's ``USE_FP8`` flag.
        topk_idx: Selected block indices; the last dim is the top-k count.
        block_table: Block table; only its leading stride is forwarded.
        cu_seqlens_q: Cumulative query lengths; ``shape[0] - 1`` is the batch
            size. Also passed in the kernel's ``cu_seqblocks_q`` slot.
        seq_lens: Per-request sequence lengths.
        prefix_lens: Per-request prefix lengths.
        max_query_len: First dimension of the launch grid.
        num_kv_heads: Number of KV heads; second dimension of the launch grid.
        sm_scale: Softmax scale forwarded to the kernel.
        output: Passed to the kernel as the output buffer.

    Returns:
        None.
    """
    # The leading dim (total_q) is not needed for the launch itself.
    _, num_heads, head_dim = q.shape
    num_seqs = cu_seqlens_q.shape[0] - 1
    num_selected = topk_idx.shape[-1]
    heads_per_kv = num_heads // num_kv_heads
    kv_is_fp8 = kv_cache.dtype in _FP8_DTYPES
    launch_grid = (max_query_len, num_kv_heads, num_seqs)

    # Strides are forwarded positionally, dim by dim, in kernel-argument order.
    q_strides = tuple(q.stride(dim) for dim in range(3))
    kv_strides = tuple(kv_cache.stride(dim) for dim in range(5))
    topk_strides = tuple(topk_idx.stride(dim) for dim in range(3))
    out_strides = tuple(output.stride(dim) for dim in range(3))

    _gqa_sparse_fwd_kernel[launch_grid](
        q,
        kv_cache,
        topk_idx,
        output,
        block_table,
        cu_seqlens_q,
        cu_seqlens_q,  # cu_seqblocks_q slot: same as cu_seqlens_q when block_size_q == 1
        seq_lens,
        prefix_lens,
        num_kv_heads,
        heads_per_kv,
        head_dim,
        num_selected,
        1,  # num_q_loop
        sm_scale,
        *q_strides,
        *kv_strides,
        *topk_strides,
        *out_strides,
        block_table.stride(0),
        BLOCK_SIZE_Q=1,
        BLOCK_SIZE_K=SPARSE_BLOCK_SIZE,
        USE_FP8=kv_is_fp8,
        SUB_K=_SPARSE_ATTN_SUB_K,
        # Arch-specific MFMA / pipeline launch params (empty dict when untuned).
        **_sparse_attn_prefill_kwargs(),
    )

minimax_m3_sparse_attn_decode(q, kv_cache, topk_idx, block_table, seq_lens, num_kv_heads, sm_scale, output, decode_query_len)

GQA block-sparse attention for decode (split-K over the top-k blocks).

Source code in vllm/models/minimax_m3/common/ops/sparse_attn.py
@torch.no_grad()
def minimax_m3_sparse_attn_decode(
    q: torch.Tensor,  # [total_q, num_heads, head_dim]
    kv_cache: torch.Tensor,  # [num_blocks, 2, 128, num_kv_heads, head_dim]
    topk_idx: torch.Tensor,  # [num_kv_heads, total_q, topk]
    block_table: torch.Tensor,  # [num_reqs, max_blocks]
    seq_lens: torch.Tensor,  # [num_reqs] int32
    num_kv_heads: int,
    sm_scale: float,
    output: torch.Tensor,  # [total_q, num_heads, head_dim]
    decode_query_len: int,
) -> None:
    """Run GQA block-sparse decode attention, split-K over the top-k blocks.

    Two kernels are launched. ``_gqa_sparse_decode_kernel`` is given per-chunk
    partial output / LSE buffers, and ``_merge_topk_attn_out_kernel`` is then
    given those partials together with ``output``. Tensor layouts are the ones
    noted next to each parameter in the signature.

    Args:
        q: Query tensor; must be 3-D (its shape is unpacked into three dims).
        kv_cache: KV cache. An FP8 dtype (member of ``_FP8_DTYPES``) turns on
            the kernel's ``USE_FP8`` flag.
        topk_idx: Selected block indices; the last dim is the max top-k count.
        block_table: Block table; only its leading stride is forwarded.
        seq_lens: Per-request sequence lengths; ``shape[0]`` times
            ``decode_query_len`` must equal ``total_q``.
        num_kv_heads: Number of KV heads.
        sm_scale: Softmax scale forwarded to the decode kernel.
        output: Passed to the merge kernel as the output buffer.
        decode_query_len: Number of query tokens per request.

    Returns:
        None.
    """
    total_q, num_heads, head_dim = q.shape
    assert total_q == seq_lens.shape[0] * decode_query_len
    max_topk = topk_idx.shape[-1]
    heads_per_kv = num_heads // num_kv_heads
    kv_is_fp8 = kv_cache.dtype in _FP8_DTYPES
    use_pdl = current_platform.is_arch_support_pdl()

    # `launch_pdl` is a Triton runtime kwarg that only some backends accept
    # (CUDA SM9+). This ROCm Triton rejects it even when it is False ("Keyword
    # argument launch_pdl was specified but unrecognised"), so the key is only
    # added when PDL is actually supported. On ROCm use_pdl is always False,
    # which leaves the dict empty.
    pdl_launch = {}
    if use_pdl:
        pdl_launch["launch_pdl"] = True

    # Split-K over the selected blocks. The chunk count depends only on shapes,
    # so it is constant for a given shape (cuda graph).
    TARGET_GRID = 256
    programs_without_split = max(1, total_q * num_kv_heads)
    chunk_budget = min(max_topk, TARGET_GRID // programs_without_split)
    chunk_budget = max(1, chunk_budget)
    # Largest power of two that does not exceed chunk_budget.
    num_topk_chunks = 1 << (chunk_budget.bit_length() - 1)

    o_partial = torch.empty(
        (num_topk_chunks, total_q, num_heads, head_dim),
        dtype=q.dtype,
        device=q.device,
    )
    lse_partial = torch.empty(
        (num_topk_chunks, total_q, num_heads),
        dtype=torch.float32,
        device=q.device,
    )

    # Both partial buffers were just allocated with a known rank (4-D and 3-D),
    # so their full stride tuples can be forwarded as-is.
    o_partial_strides = o_partial.stride()
    lse_partial_strides = lse_partial.stride()
    q_strides = tuple(q.stride(dim) for dim in range(3))
    kv_strides = tuple(kv_cache.stride(dim) for dim in range(5))
    topk_strides = tuple(topk_idx.stride(dim) for dim in range(3))
    out_strides = tuple(output.stride(dim) for dim in range(3))

    decode_grid = (total_q * num_topk_chunks, num_kv_heads)
    _gqa_sparse_decode_kernel[decode_grid](
        q,
        kv_cache,
        topk_idx,
        o_partial,
        lse_partial,
        block_table,
        seq_lens,
        total_q,
        heads_per_kv,
        head_dim,
        max_topk,
        sm_scale,
        decode_query_len,
        *q_strides,
        *kv_strides,
        *topk_strides,
        *o_partial_strides,
        *lse_partial_strides,
        block_table.stride(0),
        BLOCK_SIZE_K=SPARSE_BLOCK_SIZE,
        NUM_TOPK_CHUNKS=num_topk_chunks,
        USE_FP8=kv_is_fp8,
        USE_PDL=use_pdl,
        **pdl_launch,
    )

    # Second launch: takes the per-chunk partials and the final output buffer.
    merge_grid = (total_q, num_heads)
    _merge_topk_attn_out_kernel[merge_grid](
        o_partial,
        lse_partial,
        output,
        head_dim,
        *o_partial_strides,
        *lse_partial_strides,
        *out_strides,
        NUM_TOPK_CHUNKS=num_topk_chunks,
        USE_PDL=use_pdl,
        **pdl_launch,
    )