class MiniMaxM3SparseMSAImpl(MiniMaxM3SparseImpl):
    """MiniMax-M3 sparse attention backed by MSA block-sparse FMHA (``fmha_sm100``).

    Prefill rows are attended with ``sparse_atten_func`` from the
    ``fmha_sm100`` third-party package, restricted to the KV blocks listed in
    the prefill top-k indices. Decode rows still use the Triton split-K kernel
    ``minimax_m3_sparse_attn_decode``, because there is no MSA decode path yet.
    """

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        kv_cache: torch.Tensor,
        topk_idx: tuple[torch.Tensor | None, torch.Tensor | None],
        output: torch.Tensor,
    ) -> torch.Tensor:
        """Run sparse attention for one layer and write the results into ``output``.

        Rows ``[0, num_decode_tokens)`` of the batch are decode tokens. The
        remaining rows up to ``num_actual_tokens`` are prefill tokens.

        Args:
            layer: The attention layer. Its ``layer_name`` selects this
                layer's entry in the forward-context metadata dict.
            query: Query activations. The first ``num_actual_tokens`` rows are
                viewed as ``(-1, num_heads, head_size)``.
            kv_cache: This layer's KV cache. It is reinterpreted as
                ``self.kv_cache_fp8_dtype`` when ``self.use_fp8_kv`` is set.
            topk_idx: ``(decode_topk, prefill_topk)`` block indices. Each
                entry must be non-None whenever its phase has work.
            output: Destination buffer. The kernels receive views of it as
                their output.

        Returns:
            The same ``output`` tensor that was passed in.
        """
        fwd_md = get_forward_context().attn_metadata
        # A non-dict here means a profiling run: caches are unbound, nothing to do.
        if not isinstance(fwd_md, dict):
            return output
        md = fwd_md[layer.layer_name]  # type: ignore[attr-defined]
        assert isinstance(md, MiniMaxM3SparseMetadata)

        decode_topk, prefill_topk = topk_idx
        n_dec = md.num_decode_tokens
        n_tok = md.num_actual_tokens
        head_dim = self.head_size
        q_all = query[:n_tok].view(-1, self.num_heads, head_dim)
        o_all = output[:n_tok].view(-1, self.num_heads, head_dim)
        if self.use_fp8_kv:
            cache = kv_cache.view(self.kv_cache_fp8_dtype)
        else:
            cache = kv_cache

        if md.num_decodes > 0:
            self._msa_forward_decode(
                md, q_all[:n_dec], o_all[:n_dec], cache, decode_topk
            )
        if md.num_prefills > 0:
            self._msa_forward_prefill(
                md, q_all[n_dec:], o_all[n_dec:], cache, prefill_topk
            )
        return output

    def _msa_forward_decode(
        self,
        md: "MiniMaxM3SparseMetadata",
        q_dec: torch.Tensor,
        o_dec: torch.Tensor,
        cache: torch.Tensor,
        decode_topk: torch.Tensor | None,
    ) -> None:
        """Attend the decode rows with the Triton split-K kernel.

        This is a placeholder until MSA provides a decode path.
        """
        dec = md.decode
        assert dec is not None and decode_topk is not None
        minimax_m3_sparse_attn_decode(
            q_dec,
            cache,
            decode_topk,
            dec.block_table,
            dec.seq_lens,
            self.num_kv_heads,
            self.scale,
            o_dec,
            dec.decode_query_len,
        )

    def _msa_forward_prefill(
        self,
        md: "MiniMaxM3SparseMetadata",
        q_pre: torch.Tensor,
        o_pre: torch.Tensor,
        cache: torch.Tensor,
        prefill_topk: torch.Tensor | None,
    ) -> None:
        """Attend the prefill rows with causal MSA sparse FMHA over the selected blocks."""
        # Imported lazily, so fmha_sm100 is only loaded when there is prefill work.
        from vllm.third_party.fmha_sm100.sparse import (
            build_k2q_csr,
            sparse_atten_func,
        )

        pre = md.prefill
        assert pre is not None and prefill_topk is not None
        # K is taken from index 0 of dim 1 and V from index 1, then dims 1 and 2
        # are swapped. Presumably this produces the paged layout that
        # sparse_atten_func expects; TODO confirm against fmha_sm100.
        k_pages = cache[:, 0].transpose(1, 2)
        v_pages = cache[:, 1].transpose(1, 2)
        # Number of query heads sharing each KV head (GQA group size).
        group = q_pre.shape[1] // self.num_kv_heads

        # Invert the top-k selection into a key-block -> query CSR plus a schedule.
        # NOTE(review): total_k=0 is passed as-is; presumably unused in paged
        # mode. Verify against build_k2q_csr.
        row_ptr, q_indices, sched = build_k2q_csr(
            prefill_topk,
            pre.cu_seqlens_q,
            pre.cu_seqlens_k,
            SPARSE_BLOCK_SIZE,
            total_k=0,
            max_seqlen_k=pre.max_seq_len,
            max_seqlen_q=pre.max_query_len,
            total_rows=pre.total_kv_blocks,
            qhead_per_kv=group,
            return_schedule=True,
        )
        sparse_atten_func(
            q_pre,
            k_pages,
            v_pages,
            row_ptr,
            q_indices,
            topK=self.topk_blocks,
            blk_kv=SPARSE_BLOCK_SIZE,
            causal=True,
            softmax_scale=self.scale,
            cu_seqlens_q=pre.cu_seqlens_q,
            cu_seqlens_k=pre.cu_seqlens_k,
            max_seqlen_q=pre.max_query_len,
            max_seqlen_k=pre.max_seq_len,
            page_table=pre.block_table,
            seqused_k=pre.seq_lens,
            schedule=sched,
            out=o_pre,
        )