vllm.models.minimax_m3.common.sparse_attention ¶
Main block-sparse GQA attention for MiniMax M3 sparse layers.
The lightning indexer (indexer.py) selects the top-k KV blocks and writes them into the shared layer.topk_indices_buffer. This module holds the main attention, which attends only to those blocks: the paged K/V cache backend, its metadata and builder, and the impl that reads the indexer's top-k from that buffer. The Triton attend kernel lives here; the SM100 (MSA) attend, built on build_k2q_csr and sparse_atten_func, lives in nvidia/sparse_attention_msa.py.
MiniMaxM3SparseBackend and MiniMaxM3SparseMetadata are referenced by the attention-backend registry (by dotted path) and by spec-decode, so they must keep these names and stay in this module.
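To make "attends only to those blocks" concrete, here is a minimal pure-Python sketch of block-sparse attention for a single query vector. It is an illustration only, not the module's Triton or MSA kernel: the function name `sparse_attend` and its arguments are hypothetical, and it ignores GQA head grouping, the paged cache layout, and causal masking.

```python
import math


def sparse_attend(q, keys, values, topk_blocks, block_size):
    """Attend one query to the KV positions inside the selected blocks only.

    Hypothetical helper for illustration; assumes at least one selected
    block contains a valid position.
    """
    # Expand the selected block ids into the token positions they cover.
    positions = [
        pos
        for blk in topk_blocks
        for pos in range(blk * block_size, min((blk + 1) * block_size, len(keys)))
    ]
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, keys[p])) for p in positions]
    # Numerically stable softmax over the selected positions only.
    max_score = max(scores)
    weights = [math.exp(s - max_score) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [
        sum(w * values[p][d] for w, p in zip(weights, positions)) / total
        for d in range(dim)
    ]


# Four KV tokens in two blocks of size 2; the query attends to block 1 only.
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 2.0]]
values = [[1.0, 1.0], [2.0, 2.0], [5.0, 7.0], [5.0, 7.0]]
out = sparse_attend([1.0, 0.0], keys, values, topk_blocks=[1], block_size=2)
```

Tokens outside the selected blocks never enter the softmax, so cost scales with the number of selected blocks rather than with the full context length.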
Classes:
- MiniMaxM3SparseBackend – Block-sparse GQA backend for MiniMax M3 sparse attention layers.
- MiniMaxM3SparseDecodeMetadata – Per-decode state (cudagraph-safe). decode_query_len is the uniform per-request query length.
- MiniMaxM3SparseImpl – Abstract base for block-sparse GQA over the indexer-selected blocks.
- MiniMaxM3SparseMetadata – Sparse-attention metadata, split into prefill and decode sub-metadata.
- MiniMaxM3SparsePrefillMetadata – Per-prefill state; cu_seqlens_k/total_kv_blocks feed the MSA CSR.
- MiniMaxM3SparseTritonImpl – Triton block-sparse attend (minimax_m3_sparse_attn) + Triton decode.
Functions:
- select_main_impl_cls – Pick the main attend impl based on the main KV-cache dtype.
MiniMaxM3SparseBackend ¶
Bases: AttentionBackend
Block-sparse GQA backend for MiniMax M3 sparse attention layers.
Source code in vllm/models/minimax_m3/common/sparse_attention.py
MiniMaxM3SparseDecodeMetadata dataclass ¶
Per-decode state (cudagraph-safe). decode_query_len is the uniform per-request query length (1, or 1 + num_speculative_tokens).
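As a small illustration of the uniform query length, the sketch below computes it from the number of speculative tokens and groups a flat list of decode tokens into one fixed-length row per request. Both helper names are hypothetical and not part of the module; the grouping step is an assumption about why a uniform length is convenient for fixed-shape (cudagraph-style) execution.

```python
def decode_query_len(num_speculative_tokens: int = 0) -> int:
    """Uniform per-request query length: 1, or 1 + num_speculative_tokens."""
    if num_speculative_tokens < 0:
        raise ValueError("num_speculative_tokens must be non-negative")
    return 1 + num_speculative_tokens


def split_decode_tokens(token_ids, query_len):
    """Group a flat decode token list into one fixed-length row per request."""
    if len(token_ids) % query_len != 0:
        raise ValueError("decode tokens are not a multiple of the query length")
    return [
        token_ids[i : i + query_len] for i in range(0, len(token_ids), query_len)
    ]


# Two requests, each with 1 real token plus 3 speculative tokens.
rows = split_decode_tokens(list(range(8)), decode_query_len(3))
```

Because every decode request contributes the same number of query tokens, the decode batch forms a regular grid whose shape depends only on the number of requests.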
MiniMaxM3SparseImpl ¶
Bases: AttentionImplBase[MiniMaxM3SparseMetadata]
Abstract base for block-sparse GQA over the indexer-selected blocks.
Inherits from AttentionImplBase to allow a custom forward signature: the layer pre-inserts K/V and runs the indexer, which writes the selected blocks into the shared layer.topk_indices_buffer, and the attend reads them back from there. The Triton and MSA subclasses each own a full forward; they share no forward code.
Methods:
- forward – Attend the queries to the indexer-selected blocks; each kernel-specific subclass provides its own implementation.
forward(layer, query, kv_cache, output) ¶
Attend the queries to the indexer-selected blocks. Each kernel-specific subclass provides its own implementation.
The indexer has already written the top-k block ids into layer.topk_indices_buffer (decode at [:, :nd], prefill at [:, nd:num_tokens]); the attend reads them from there.
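The slicing convention above can be sketched with plain nested lists. This is a toy model under stated assumptions: `nd` is taken to be the number of decode tokens, the second axis is taken to be the token axis, and the buffer may be larger than `num_tokens`. The helper name `split_topk_buffer` is hypothetical, and the real buffer is a tensor whose leading axis is not described here.

```python
def split_topk_buffer(buffer, nd, num_tokens):
    """Split each row of the shared buffer into decode and prefill columns.

    Mirrors the documented slices: decode at [:, :nd],
    prefill at [:, nd:num_tokens].
    """
    decode = [row[:nd] for row in buffer]
    prefill = [row[nd:num_tokens] for row in buffer]
    return decode, prefill


# Two rows, capacity for 6 tokens, 5 tokens in use, the first 2 are decode.
buffer = [
    [10, 11, 12, 13, 14, -1],
    [20, 21, 22, 23, 24, -1],
]
decode_ids, prefill_ids = split_topk_buffer(buffer, nd=2, num_tokens=5)
```

Keeping decode tokens first means the decode slice always starts at column 0, regardless of how many prefill tokens follow it.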
MiniMaxM3SparseMetadata dataclass ¶
Bases: AttentionMetadata
Sparse-attention metadata, split into prefill and decode sub-metadata.
MiniMaxM3SparsePrefillMetadata dataclass ¶
Per-prefill state; cu_seqlens_k/total_kv_blocks feed the MSA CSR.
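The exact definitions of these fields are not shown on this page. The sketch below assumes the convention common to variable-length attention kernels: `cu_seqlens_k` holds cumulative KV lengths with a leading zero, and `total_kv_blocks` is the number of KV blocks summed over all prefill sequences. The helper name and the `block_size` parameter are hypothetical.

```python
from itertools import accumulate


def prefill_kv_layout(kv_lens, block_size):
    """Return (cu_seqlens_k, total_kv_blocks) for a batch of prefill sequences.

    Assumed semantics, not confirmed by the source:
    cu_seqlens_k[i] is the offset of sequence i in a packed KV layout, and
    each sequence occupies ceil(len / block_size) KV blocks.
    """
    cu_seqlens_k = [0, *accumulate(kv_lens)]
    # Ceiling division without floats: -(-n // b) == ceil(n / b) for n >= 0.
    blocks_per_seq = [-(-n // block_size) for n in kv_lens]
    return cu_seqlens_k, sum(blocks_per_seq)


# Three prefill sequences with KV lengths 5, 8 and 3 and a block size of 4.
cu_seqlens_k, total_kv_blocks = prefill_kv_layout([5, 8, 3], block_size=4)
```

Under these assumptions, a CSR structure over KV blocks needs one row per KV block, so the total block count fixes its size up front.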
MiniMaxM3SparseTritonImpl ¶
Bases: MiniMaxM3SparseImpl
Triton block-sparse attend (minimax_m3_sparse_attn) + Triton decode.
select_main_impl_cls(*, topk_blocks, kv_cache_dtype) ¶
Pick the main attend impl based on the main KV-cache dtype.
Blackwell (SM100) uses the MSA attend for supported top-k block counts when the KV cache is BF16 or FP8 E4M3; non-Blackwell devices and FP8 E5M2 caches fall back to Triton. The MSA module is imported lazily so that AMD and other non-SM100 platforms never import fmha_sm100.
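The selection rule and the lazy import can be sketched as follows. This is not the module's code: `pick_impl_name`, `load_impl`, the `is_sm100` flag, the dtype strings, and the set of supported top-k block counts are all assumptions made for illustration (the source does not list the supported counts), and the demo registry uses standard-library modules as stand-ins for the real implementations.

```python
import importlib

# Hypothetical values: the source does not say which top-k counts MSA supports.
ASSUMED_MSA_TOPK_BLOCKS = frozenset({16, 32})
# Assumed string spellings for a BF16 and an FP8 E4M3 KV cache.
MSA_KV_DTYPES = frozenset({"bfloat16", "fp8_e4m3"})


def pick_impl_name(*, topk_blocks, kv_cache_dtype, is_sm100):
    """Return "msa" only when all three documented conditions hold."""
    if (
        is_sm100
        and kv_cache_dtype in MSA_KV_DTYPES
        and topk_blocks in ASSUMED_MSA_TOPK_BLOCKS
    ):
        return "msa"
    # Non-Blackwell devices, FP8 E5M2, and unsupported top-k counts use Triton.
    return "triton"


def load_impl(name, registry):
    """Import the chosen implementation's module only when it is selected."""
    module_name, attr = registry[name]
    module = importlib.import_module(module_name)  # lazy: runs on demand
    return getattr(module, attr)


# Stand-in registry; real entries would point at the Triton and MSA classes.
demo_registry = {"triton": ("math", "sqrt"), "msa": ("math", "floor")}
choice = pick_impl_name(topk_blocks=16, kv_cache_dtype="fp8_e5m2", is_sm100=True)
impl = load_impl(choice, demo_registry)
```

Deferring the import to the selected branch means a platform that never selects the MSA path never loads its SM100-only dependencies.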