Skip to content

vllm.models.minimax_m3.common.sparse_attention

Main block-sparse GQA attention for MiniMax M3 sparse layers.

The lightning indexer (indexer.py) selects the top-k KV blocks and writes them into the shared layer.topk_indices_buffer. This module holds the main attention that attends only to those blocks: the paged K/V cache backend, its metadata and metadata builder, and the impl that reads the indexer's top-k from that buffer. The Triton attend kernel lives here; the SM100 (MSA) attend, built on build_k2q_csr + sparse_atten_func, lives in nvidia/sparse_attention_msa.py.

MiniMaxM3SparseBackend and MiniMaxM3SparseMetadata are referenced by the attention-backend registry (by dotted path) and by spec-decode, so they must keep these names and stay in this module.

Classes:

Functions:

MiniMaxM3SparseBackend

Bases: AttentionBackend

Block-sparse GQA backend for MiniMax M3 sparse attention layers.

Source code in vllm/models/minimax_m3/common/sparse_attention.py
class MiniMaxM3SparseBackend(AttentionBackend):
    """Block-sparse GQA backend for MiniMax M3 sparse attention layers.

    Each sparse block maps onto exactly one KV-cache page (page size 128),
    and only 128-wide heads are supported.
    """

    # Model dtypes accepted by this backend.
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16, torch.float16]
    # KV cache may be bf16 or fp8 (generic / e4m3 / e5m2); the Triton kernels
    # dequantize fp8 before the dot products.
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "bfloat16",
        "fp8",
        "fp8_e4m3",
        "fp8_e5m2",
    ]

    @staticmethod
    def get_name() -> str:
        """Registry name of this backend."""
        return "MINIMAX_M3_SPARSE"

    @staticmethod
    def get_impl_cls() -> type["MiniMaxM3SparseImpl"]:
        """Return the abstract impl base.

        The concrete impl (Triton or MSA) is chosen by
        ``select_main_impl_cls``; the base is returned for introspection.
        """
        return MiniMaxM3SparseImpl

    @staticmethod
    def get_builder_cls() -> type["MiniMaxM3SparseMetadataBuilder"]:
        """Return the metadata builder class for this backend."""
        return MiniMaxM3SparseMetadataBuilder

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Only a head size of 128 is supported."""
        return [128]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """KV page size, which equals the sparse block size (one block per page)."""
        return [128]

    @classmethod
    def is_sparse(cls) -> bool:
        """This backend attends only to indexer-selected blocks."""
        return True

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Logical KV-cache shape ``(num_blocks, 2, block_size, num_kv_heads, head_size)``.

        The size-2 dim presumably separates K and V (TODO confirm ordering).
        ``cache_dtype_str`` does not affect the shape.
        """
        return num_blocks, 2, block_size, num_kv_heads, head_size

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """Permutation from the ``get_kv_cache_shape`` order to the memory layout.

        Raises:
            NotImplementedError: if ``include_num_layers_dimension`` is set.
            ValueError: if the configured cache layout is neither NHD nor HND.
        """
        if include_num_layers_dimension:
            raise NotImplementedError  # no cross-layer KV blocks in M3
        layout = get_kv_cache_layout()
        permutations = {
            "NHD": (0, 1, 2, 3, 4),  # identity
            "HND": (0, 1, 3, 2, 4),  # swap block_size and num_kv_heads
        }
        order = permutations.get(layout)
        if order is None:
            raise ValueError(f"Unknown cache layout format {layout}.")
        return order

MiniMaxM3SparseDecodeMetadata dataclass

Per-decode state (cudagraph-safe). decode_query_len is the uniform per-request query length (1, or 1 + num_speculative_tokens).

Source code in vllm/models/minimax_m3/common/sparse_attention.py
@dataclass
class MiniMaxM3SparseDecodeMetadata:
    """Per-decode state (cudagraph-safe) read by the decode attend path.

    ``decode_query_len`` is the uniform per-request query length (1, or
    1 + num_speculative_tokens).
    """

    seq_lens: torch.Tensor  # [num_decodes] int32
    # Page table for the decode requests; presumably [num_decodes,
    # max_blocks_per_req] KV page ids -- TODO confirm against the builder.
    block_table: torch.Tensor
    # Query tokens per decode request; identical for every decode request.
    decode_query_len: int

MiniMaxM3SparseImpl

Bases: AttentionImplBase[MiniMaxM3SparseMetadata]

Abstract base for block-sparse GQA over the indexer-selected blocks.

Inherits AttentionImplBase for a custom forward signature (the layer pre-inserts K/V and runs the indexer, which writes the selected blocks into the shared layer.topk_indices_buffer; the attend reads them back from there). The Triton and MSA subclasses each implement a complete forward; there is no shared forward code.

Methods:

  • forward

    Attend the queries to the indexer-selected blocks. Implemented separately by each kernel subclass (Triton, MSA).

Source code in vllm/models/minimax_m3/common/sparse_attention.py
class MiniMaxM3SparseImpl(AttentionImplBase[MiniMaxM3SparseMetadata]):
    """Abstract base for block-sparse GQA over the indexer-selected blocks.

    Inherits ``AttentionImplBase`` for a custom forward signature (the layer
    pre-inserts K/V and runs the indexer, which writes the selected blocks into
    the shared ``layer.topk_indices_buffer``; the attend reads them back from
    there). The Triton and MSA subclasses each implement a complete
    ``forward``; there is no shared forward code.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        kv_cache_dtype: str = "auto",
        *,
        topk_blocks: int,
        sparse_block_size: int,
    ) -> None:
        """Store head geometry, KV-cache dtype info and sparse-selection params.

        Args:
            num_heads: Number of query heads.
            head_size: Per-head dimension.
            scale: Attention score scale passed to the attend kernels.
            num_kv_heads: Number of KV heads; defaults to ``num_heads``.
            kv_cache_dtype: KV-cache dtype string ("auto", "bfloat16",
                "fp8", "fp8_e4m3", "fp8_e5m2", ...).
            topk_blocks: Number of top-k KV blocks selected by the indexer.
            sparse_block_size: Sparse block size; equals the KV page size.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        # Without an explicit KV-head count, use one KV head per query head.
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        # NOTE(review): "bfloat16" is an advertised KV-cache dtype; confirm
        # is_quantized_kv_cache("bfloat16") is False, otherwise a bf16 cache
        # would be reinterpreted as fp8 by the attend.
        self.use_fp8_kv = is_quantized_kv_cache(kv_cache_dtype)
        # fp8 dtype the raw cache is viewed as when use_fp8_kv is set. It is
        # computed unconditionally, even for non-fp8 caches.
        if "e5m2" not in kv_cache_dtype:
            self.kv_cache_fp8_dtype = current_platform.fp8_dtype()
        elif current_platform.is_fp8_fnuz():
            self.kv_cache_fp8_dtype = torch.float8_e5m2fnuz
        else:
            self.kv_cache_fp8_dtype = torch.float8_e5m2
        # Sparse selection parameters (block_size == page size == SPARSE_BLOCK_SIZE).
        self.topk_blocks = topk_blocks
        self.block_size = sparse_block_size

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        output: torch.Tensor,
    ) -> torch.Tensor:
        """Attend the queries to the indexer-selected blocks.

        Implemented separately by each kernel subclass (Triton, MSA). The
        indexer has already written the top-k block ids into
        ``layer.topk_indices_buffer`` (decode at ``[:, :nd]``, prefill at
        ``[:, nd:num_tokens]``); the attend reads them from there.

        Raises:
            NotImplementedError: always, on this abstract base.
        """
        raise NotImplementedError(
            f"{type(self).__name__}.forward is not implemented; use the "
            "kernel-specific impl returned by select_main_impl_cls()"
        )

forward(layer, query, kv_cache, output)

Attend the queries to the indexer-selected blocks. Implemented separately by each kernel subclass (Triton, MSA).

The indexer has already written the top-k block ids into layer.topk_indices_buffer (decode at [:, :nd], prefill at [:, nd:num_tokens]); the attend reads them from there.

Source code in vllm/models/minimax_m3/common/sparse_attention.py
def forward(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    kv_cache: torch.Tensor,
    output: torch.Tensor,
) -> torch.Tensor:
    """Attend the queries to the indexer-selected blocks.

    Implemented separately by each kernel subclass (Triton, MSA). The
    indexer has already written the top-k block ids into
    ``layer.topk_indices_buffer`` (decode at ``[:, :nd]``, prefill at
    ``[:, nd:num_tokens]``); the attend reads them from there.

    Raises:
        NotImplementedError: always, on the abstract base.
    """
    raise NotImplementedError(
        f"{type(self).__name__}.forward is not implemented; use the "
        "kernel-specific impl returned by select_main_impl_cls()"
    )

MiniMaxM3SparseMetadata dataclass

Bases: AttentionMetadata

Sparse-attention metadata, split into prefill and decode sub-metadata.

Source code in vllm/models/minimax_m3/common/sparse_attention.py
@dataclass
class MiniMaxM3SparseMetadata(AttentionMetadata):
    """Sparse-attention metadata, split into prefill and decode sub-metadata.

    The batch is ordered decode-first: query tokens ``[:num_decode_tokens]``
    belong to decode requests and ``[num_decode_tokens:num_actual_tokens]``
    to prefill requests.
    """

    # Presumably per-request total KV lengths -- TODO confirm.
    seq_lens: torch.Tensor
    max_seq_len: int
    # Presumably maps each query token to its KV-cache slot for K/V insertion
    # -- TODO confirm.
    slot_mapping: torch.Tensor

    num_actual_tokens: int  # total query tokens (decode-first batch)

    # Split counts (batch reordered decode-first).
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # Sub-metadata; the Triton impl asserts each is set whenever the
    # corresponding count (num_prefills / num_decodes) is non-zero.
    prefill: MiniMaxM3SparsePrefillMetadata | None = None
    decode: MiniMaxM3SparseDecodeMetadata | None = None

MiniMaxM3SparsePrefillMetadata dataclass

Per-prefill state; cu_seqlens_k and total_kv_blocks are the inputs used to build the MSA CSR.

Source code in vllm/models/minimax_m3/common/sparse_attention.py
@dataclass
class MiniMaxM3SparsePrefillMetadata:
    """Per-prefill state for the prefill attend.

    ``cu_seqlens_k`` and ``total_kv_blocks`` are inputs to the MSA CSR build.
    The Triton prefill path reads ``block_table``, ``cu_seqlens_q``,
    ``seq_lens``, ``context_lens`` and ``max_query_len``.
    """

    cu_seqlens_q: torch.Tensor  # [num_prefills + 1] int32, rebased to 0
    cu_seqlens_k: torch.Tensor  # [num_prefills + 1] int32, cumulative KV lengths
    seq_lens: torch.Tensor  # [num_prefills] int32, total KV lengths
    context_lens: torch.Tensor  # [num_prefills] int32 (cached/context tokens)
    # Page table for the prefill requests; layout not visible here -- TODO confirm.
    block_table: torch.Tensor
    max_query_len: int  # presumably the longest prefill query length
    max_seq_len: int  # presumably the longest prefill KV length
    # Presumably the total number of KV blocks over all prefill requests
    # -- TODO confirm against the builder.
    total_kv_blocks: int

MiniMaxM3SparseTritonImpl

Bases: MiniMaxM3SparseImpl

Triton block-sparse attend (minimax_m3_sparse_attn) + Triton decode.

Source code in vllm/models/minimax_m3/common/sparse_attention.py
class MiniMaxM3SparseTritonImpl(MiniMaxM3SparseImpl):
    """Triton block-sparse attend (``minimax_m3_sparse_attn``) + Triton decode.

    Both kernels read the indexer's top-k block ids from the shared
    ``layer.topk_indices_buffer`` and write into ``output`` in place.
    """

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        output: torch.Tensor,
    ) -> torch.Tensor:
        """Run the decode then the prefill sparse attend; return ``output``.

        Rows ``[:num_decode_tokens]`` are decode tokens and rows
        ``[num_decode_tokens:num_actual_tokens]`` are prefill tokens.
        """
        fwd_md = get_forward_context().attn_metadata
        if not isinstance(fwd_md, dict):
            return output  # profiling run; caches unbound
        md = fwd_md[layer.layer_name]  # type: ignore[attr-defined]
        assert isinstance(md, MiniMaxM3SparseMetadata)

        n_tokens = md.num_actual_tokens
        # Shared indexer top-k: decode at [:, :nd], prefill at [:, nd:n_tokens].
        topk = layer.topk_indices_buffer  # type: ignore[attr-defined]
        assert topk is not None
        head_shape = (-1, self.num_heads, self.head_size)
        q = query[:n_tokens].view(*head_shape)
        out = output[:n_tokens].view(*head_shape)
        if self.use_fp8_kv:
            # Reinterpret the raw cache bytes as fp8; the kernels dequantize.
            kv_cache = kv_cache.view(self.kv_cache_fp8_dtype)

        if md.num_decodes > 0:
            self._attend_decode(md, q, kv_cache, topk, out)
        if md.num_prefills > 0:
            self._attend_prefill(md, q, kv_cache, topk, out)
        return output

    def _attend_decode(
        self,
        md: "MiniMaxM3SparseMetadata",
        q: torch.Tensor,
        kv_cache: torch.Tensor,
        topk: torch.Tensor,
        out: torch.Tensor,
    ) -> None:
        """Decode rows ``[:nd]``: split-K over the selected blocks (request-major chunks)."""
        dec = md.decode
        assert dec is not None
        nd = md.num_decode_tokens
        minimax_m3_sparse_attn_decode(
            q[:nd],
            kv_cache,
            topk[:, :nd, :],
            dec.block_table,
            dec.seq_lens,
            self.num_kv_heads,
            self.scale,
            out[:nd],
            dec.decode_query_len,
        )

    def _attend_prefill(
        self,
        md: "MiniMaxM3SparseMetadata",
        q: torch.Tensor,
        kv_cache: torch.Tensor,
        topk: torch.Tensor,
        out: torch.Tensor,
    ) -> None:
        """Prefill rows ``[nd:]``; ``cu_seqlens_q`` is already rebased to 0."""
        pre = md.prefill
        assert pre is not None
        start = md.num_decode_tokens
        stop = md.num_actual_tokens
        minimax_m3_sparse_attn(
            q[start:],
            kv_cache,
            topk[:, start:stop, :],
            pre.block_table,
            pre.cu_seqlens_q,
            pre.seq_lens,
            pre.context_lens,
            pre.max_query_len,
            self.num_kv_heads,
            self.scale,
            out[start:],
        )

select_main_impl_cls(*, topk_blocks, kv_cache_dtype)

Pick the main attend impl off the main KV-cache dtype.

On Blackwell (SM100), the MSA attend is used when topk_blocks is 4, 8, 16, or 32 and the KV cache is BF16 or FP8 E4M3. Non-Blackwell GPUs, other top-k values, and FP8 E5M2 fall back to Triton. The MSA module is imported lazily so AMD and non-SM100 platforms never import fmha_sm100.

Source code in vllm/models/minimax_m3/common/sparse_attention.py
def select_main_impl_cls(
    *,
    topk_blocks: int,
    kv_cache_dtype: str,
) -> type[MiniMaxM3SparseImpl]:
    """Pick the main attend impl from the platform, top-k and KV-cache dtype.

    On Blackwell (SM100) the MSA attend is used when ``topk_blocks`` is one
    of 4, 8, 16, 32 and the KV cache is not FP8 E5M2 (i.e. BF16 or FP8
    E4M3). Every other case falls back to Triton: non-CUDA, non-SM100, other
    top-k values, or E5M2. The MSA module is imported lazily so AMD and
    non-SM100 platforms never import fmha_sm100.

    Args:
        topk_blocks: Number of KV blocks the indexer selects.
        kv_cache_dtype: KV-cache dtype string.

    Returns:
        The impl class to instantiate (MSA or Triton).
    """
    # Use the same substring test MiniMaxM3SparseImpl.__init__ uses to pick
    # the E5M2 fp8 dtype, so the selector and the impl agree on "E5M2".
    is_e5m2 = "e5m2" in kv_cache_dtype
    # NOTE(review): only E5M2 is excluded. If "auto" can resolve to an fp16
    # cache here (fp16 is in the backend's supported_dtypes), MSA would be
    # picked for it too -- confirm the MSA attend supports fp16.
    use_msa = (
        current_platform.is_cuda()
        and current_platform.is_device_capability_family(100)
        and topk_blocks in (4, 8, 16, 32)
        and not is_e5m2
    )
    selected = "MSA" if use_msa else "Triton"
    logger.info_once(
        "MiniMax M3 sparse attention selected %s (kv_cache_dtype=%s, topk_blocks=%s)",
        selected,
        kv_cache_dtype,
        topk_blocks,
    )
    if not use_msa:
        return MiniMaxM3SparseTritonImpl

    # Lazy import: keeps fmha_sm100 out of non-SM100 / AMD processes.
    from vllm.models.minimax_m3.nvidia.sparse_attention_msa import (
        MiniMaxM3SparseMSAImpl,
    )

    return MiniMaxM3SparseMSAImpl