class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
    """FlashInfer TRTLLM-gen sparse MLA implementation for DeepSeek V4.

    The methods defined here are classmethods that receive the attention
    ``layer`` as an explicit argument. The kernel path accepts a bf16 or a
    per-tensor FP8 (``float8_e4m3fn``) KV cache, pads the query heads up to
    64 or 128, and launches decode tokens and prefill tokens as two separate
    kernel calls.
    """

    # Backend class associated with this implementation.
    backend_cls = DeepseekV4FlashInferMLASparseBackend
@classmethod
def get_padded_num_q_heads(cls, num_heads: int) -> int:
    """Return the query head count to pad to for the kernel.

    The FP8 decode kernel only supports ``h_q`` of 64 or 128, so the head
    count is rounded up to the smallest supported value.

    Args:
        num_heads: Number of query heads before padding.

    Returns:
        ``64`` when ``num_heads <= 64``, otherwise ``128``.

    Raises:
        ValueError: If ``num_heads`` is larger than 128.
    """
    # Supported sizes in ascending order; the first one that fits wins.
    for supported_heads in (64, 128):
        if num_heads <= supported_heads:
            return supported_heads
    raise ValueError(
        f"DeepseekV4 Flashinfer MLA Sparse does not support {num_heads} heads "
        "(FP8 decode kernel requires h_q in {64, 128})."
    )
@classmethod
def init_layer_buffers(cls, layer: "DeepseekV4MLAAttention") -> None:
# Per-tensor FP8 scale buffers + precomputed scalar BMM scales. Only the
# per-tensor FP8 cache path consumes these; bf16 reads ``layer.scale``.
if layer.kv_cache_torch_dtype != torch.float8_e4m3fn:
return
# TODO: load real per-tensor Q/KV scales from the checkpoint; unit
# scales until the scale tensor names are wired.
fp8_q_scale = 1.0
fp8_kv_scale = 1.0
layer.register_buffer(
"_flashinfer_fp8_q_scale",
torch.tensor([fp8_q_scale], dtype=torch.float32),
persistent=False,
)
layer.register_buffer(
"_flashinfer_fp8_q_scale_inv",
torch.tensor([1.0 / fp8_q_scale], dtype=torch.float32),
persistent=False,
)
layer.register_buffer(
"_flashinfer_fp8_kv_scale",
torch.tensor([fp8_kv_scale], dtype=torch.float32),
persistent=False,
)
# TRTLLM-gen takes scalar scale args on a distinct (correct) C++ path
# vs 1-elem tensors, so these are Python floats. bmm1 folds the softmax
# scale and the Q/KV per-tensor scales; bmm2 is the KV scale.
layer._flashinfer_fp8_bmm1_scale = layer.scale * fp8_q_scale * fp8_kv_scale
layer._flashinfer_fp8_bmm2_scale = fp8_kv_scale
@classmethod
def forward_mqa( # type: ignore[override]
cls,
layer: "DeepseekV4MLAAttention",
q: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
output: torch.Tensor,
) -> None:
# The TRTLLM-gen kernel requires h_q in {64, 128}, so the output buffer
# is allocated at the padded head count while q arrives at the local
# head count; _forward pads q to match before the launcher.
assert output.shape[0] == q.shape[0] and output.shape[-1] == q.shape[-1], (
f"output buffer shape {output.shape} incompatible with q shape {q.shape}"
)
assert output.shape[1] >= q.shape[1], (
f"output heads {output.shape[1]} must be >= q heads {q.shape[1]}"
)
# Per-tensor FP8 q produces a bf16 attention output.
expected_output_dtype = (
torch.bfloat16 if q.dtype == torch.float8_e4m3fn else q.dtype
)
assert output.dtype == expected_output_dtype, (
f"output dtype {output.dtype} must match expected {expected_output_dtype} "
f"for q dtype {q.dtype}"
)
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
# Warmup dummy run: FlashInfer reads the cache directly and lazily
# allocates its workspace, so nothing to reserve here.
output.zero_()
return
assert isinstance(attn_metadata, dict)
flashmla_metadata = cast(
FlashMLASparseMetadata | None, attn_metadata.get(layer.prefix)
)
swa_metadata = cast(
"DeepseekSparseSWAMetadata | None",
attn_metadata.get(layer.swa_cache_layer.prefix),
)
assert swa_metadata is not None
swa_only = layer.compress_ratio <= 1
# SWA-only layers don't allocate their own compressed KV cache.
self_kv_cache = layer.kv_cache if not swa_only else None
swa_kv_cache = layer.swa_cache_layer.kv_cache
cls._forward(
layer=layer,
q=q,
kv_cache=self_kv_cache,
swa_k_cache=swa_kv_cache,
swa_metadata=swa_metadata,
attn_metadata=flashmla_metadata,
swa_only=swa_only,
output=output,
)
@classmethod
def _build_sparse_index_metadata(
    cls,
    layer: "DeepseekV4MLAAttention",
    kv_cache: torch.Tensor | None,
    swa_k_cache: torch.Tensor,
    swa_metadata: "DeepseekSparseSWAMetadata",
    attn_metadata: FlashMLASparseMetadata | None,
    swa_only: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the combined sparse-index tensors for the mixed batch.

    Chooses, from ``swa_only`` and ``layer.compress_ratio``, which cache is
    returned as the compressed KV cache and where the decode / prefill
    top-k indices are read from, then passes everything to
    ``build_flashinfer_mixed_sparse_indices``.

    Args:
        layer: Attention layer; ``window_size``, ``compress_ratio`` and
            ``topk_indices_buffer`` are read from it.
        kv_cache: The layer's own KV cache. Must not be ``None`` unless
            ``swa_only`` is true.
        swa_k_cache: SWA KV cache; returned as the compressed cache when
            ``swa_only`` is true.
        swa_metadata: Supplies the batch counts, ``seq_lens``,
            ``query_start_loc``, ``token_to_req_indices``,
            ``decode_swa_indices``, ``block_table`` and ``block_size``.
        attn_metadata: Metadata for the layer's own cache. Must not be
            ``None`` unless ``swa_only`` is true.
        swa_only: Whether to use the SWA cache in place of a separate
            compressed cache.

    Returns ``(compressed_kv_cache, seq_lens, sparse_indices,
    sparse_topk_lens)``.

    Raises:
        AssertionError: If a required metadata field or buffer is ``None``,
            or if ``seq_lens`` is not ``torch.int32``.
    """
    # Batch composition: tokens and requests are addressed decode-first,
    # then prefill, throughout this method.
    num_decodes = swa_metadata.num_decodes
    num_prefills = swa_metadata.num_prefills
    num_decode_tokens = swa_metadata.num_decode_tokens
    num_prefill_tokens = swa_metadata.num_prefill_tokens
    num_reqs = num_decodes + num_prefills
    num_tokens = num_decode_tokens + num_prefill_tokens
    assert swa_metadata.seq_lens is not None
    assert swa_metadata.query_start_loc is not None
    assert swa_metadata.token_to_req_indices is not None
    assert swa_metadata.decode_swa_indices is not None
    assert swa_metadata.block_table is not None
    # One row of ``window_size`` SWA indices per decode token.
    decode_swa_indices = swa_metadata.decode_swa_indices.reshape(
        num_decode_tokens, layer.window_size
    )
    # Optional builder inputs; only some branches below override them.
    decode_compressed_topk_lens = None
    decode_compressed_indices_are_local = False
    decode_is_valid_token = None
    if swa_only:
        # No separate compressed cache: hand back the SWA cache, pass no
        # compressed decode indices or block table, and use a zero-width
        # prefill top-k slice with top_k == 0.
        assert layer.topk_indices_buffer is not None
        compressed_kv_cache = swa_k_cache
        decode_compressed_indices = None
        prefill_topk_indices = layer.topk_indices_buffer[
            num_decode_tokens:num_tokens, :0
        ]
        compressed_block_table = None
        compressed_block_size = swa_metadata.block_size
        top_k = 0
    else:
        assert kv_cache is not None
        assert attn_metadata is not None
        compressed_kv_cache = kv_cache
        compressed_block_table = attn_metadata.block_table[:num_reqs]
        # NOTE(review): presumably the block size counted in compressed
        # entries rather than raw tokens -- confirm against the builder.
        compressed_block_size = attn_metadata.block_size // layer.compress_ratio
        if layer.compress_ratio == 4:
            # Both decode and prefill rows are read from the layer's
            # shared top-k buffer (decode rows first, then prefill rows).
            assert layer.topk_indices_buffer is not None
            if num_prefill_tokens > 0:
                prefill_topk_indices = layer.topk_indices_buffer[
                    num_decode_tokens:num_tokens
                ]
                top_k = prefill_topk_indices.shape[-1]
            else:
                # Empty placeholder when the batch has no prefill tokens.
                prefill_topk_indices = layer.topk_indices_buffer[:0, :0]
                top_k = 0
            # Tell the builder these decode indices are "local" and give
            # it the per-token validity mask.
            # NOTE(review): the exact meaning of "local" is defined by
            # build_flashinfer_mixed_sparse_indices -- confirm there.
            decode_compressed_indices_are_local = True
            assert swa_metadata.is_valid_token is not None
            decode_is_valid_token = swa_metadata.is_valid_token[:num_decode_tokens]
            if num_decode_tokens > 0:
                decode_compressed_indices = layer.topk_indices_buffer[
                    :num_decode_tokens
                ]
            else:
                # Keep the logical width aligned with the mixed-batch case so
                # pure-prefill steps reuse the same Triton specialization.
                decode_compressed_indices = prefill_topk_indices[:0]
        else:
            # Any other compress ratio: indices come precomputed from
            # ``attn_metadata``.
            # NOTE(review): the ``c128a_*`` field names suggest this branch
            # serves compress_ratio == 128 -- confirm.
            if num_prefill_tokens > 0:
                assert attn_metadata.c128a_prefill_topk_indices is not None
                prefill_topk_indices = attn_metadata.c128a_prefill_topk_indices
                top_k = prefill_topk_indices.shape[-1]
            else:
                # Provisional empty placeholder; replaced below when the
                # batch has decode tokens.
                prefill_topk_indices = decode_swa_indices[:0, :0]
                top_k = 0
            if num_decode_tokens > 0:
                assert attn_metadata.c128a_global_decode_topk_indices is not None
                assert attn_metadata.c128a_decode_topk_lens is not None
                # Flatten to one row of indices per decode token.
                decode_compressed_indices = (
                    attn_metadata.c128a_global_decode_topk_indices.view(
                        num_decode_tokens, -1
                    )
                )
                decode_compressed_topk_lens = attn_metadata.c128a_decode_topk_lens
                if num_prefill_tokens == 0:
                    # Pure-decode step: derive the empty prefill placeholder
                    # from the decode indices instead of the SWA indices.
                    prefill_topk_indices = decode_compressed_indices[:0, :0]
            else:
                # No decode tokens: pass empty decode inputs.
                decode_compressed_indices = prefill_topk_indices[:0]
                decode_compressed_topk_lens = swa_metadata.seq_lens[:0]
    # Trim the per-request metadata to the requests in this step.
    query_start_loc = swa_metadata.query_start_loc[: num_reqs + 1]
    seq_lens = swa_metadata.seq_lens[:num_reqs]
    assert seq_lens.dtype == torch.int32
    sparse_indices, sparse_topk_lens = build_flashinfer_mixed_sparse_indices(
        decode_swa_indices,
        decode_compressed_indices,
        decode_compressed_topk_lens,
        prefill_topk_indices[:num_prefill_tokens],
        query_start_loc,
        seq_lens,
        swa_metadata.token_to_req_indices[:num_tokens],
        swa_metadata.block_table[:num_reqs],
        swa_metadata.block_size,
        compressed_block_table,
        compressed_block_size,
        layer.window_size,
        layer.compress_ratio,
        top_k,
        decode_compressed_indices_are_local=decode_compressed_indices_are_local,
        decode_is_valid_token=decode_is_valid_token,
    )
    return compressed_kv_cache, seq_lens, sparse_indices, sparse_topk_lens
@classmethod
def _forward(
    cls,
    layer: "DeepseekV4MLAAttention",
    q: torch.Tensor,
    kv_cache: torch.Tensor | None,
    swa_k_cache: torch.Tensor,
    swa_metadata: "DeepseekSparseSWAMetadata",
    attn_metadata: FlashMLASparseMetadata | None,
    swa_only: bool,
    output: torch.Tensor,
) -> None:
    """Run sparse MLA attention for the scheduled tokens into ``output``.

    Builds the sparse-index metadata, trims ``q`` / ``output`` to the
    scheduled token count, zero-pads the query heads up to the head count
    of ``output``, then launches the kernel once for the decode tokens and
    once for the prefill tokens (whichever are present).

    Args:
        layer: Attention layer supplying the cache dtype, scales and
            ``attn_sink``.
        q: Query tensor, indexed here as ``(tokens, heads, dim)``.
        kv_cache: The layer's own KV cache, or ``None``; forwarded to
            ``_build_sparse_index_metadata``.
        swa_k_cache: SWA KV cache passed to the kernel.
        swa_metadata: Batch counts and query offsets (device and CPU).
        attn_metadata: Metadata for the layer's own cache, or ``None``.
        swa_only: Forwarded to ``_build_sparse_index_metadata``.
        output: Destination buffer, written in place.

    Raises:
        AssertionError: If the cache dtype is neither bf16 nor FP8, if the
            query dtype does not match the cache dtype, or if the query
            start offsets are missing.
    """
    cache_dtype = layer.kv_cache_torch_dtype
    assert cache_dtype in (torch.bfloat16, torch.float8_e4m3fn)

    n_decode_reqs = swa_metadata.num_decodes
    n_prefill_reqs = swa_metadata.num_prefills
    n_decode_toks = swa_metadata.num_decode_tokens
    n_prefill_toks = swa_metadata.num_prefill_tokens
    n_reqs = n_decode_reqs + n_prefill_reqs
    n_toks = n_decode_toks + n_prefill_toks
    if n_toks == 0:
        return

    index_metadata = cls._build_sparse_index_metadata(
        layer=layer,
        kv_cache=kv_cache,
        swa_k_cache=swa_k_cache,
        swa_metadata=swa_metadata,
        attn_metadata=attn_metadata,
        swa_only=swa_only,
    )
    compressed_kv_cache, seq_lens, sparse_indices, sparse_topk_lens = index_metadata

    # CUDA graph execution can pad q/output past the scheduled token count;
    # restrict to the real tokens (the launcher validates sparse indices).
    q_view = q[:n_toks]
    out_view = output[:n_toks]

    bmm1_scale: float | torch.Tensor
    bmm2_scale: float | torch.Tensor
    if cache_dtype == torch.float8_e4m3fn:
        assert q_view.dtype == torch.float8_e4m3fn
        bmm1_scale = layer._flashinfer_fp8_bmm1_scale
        bmm2_scale = layer._flashinfer_fp8_bmm2_scale
    else:
        assert q_view.dtype == torch.bfloat16
        bmm1_scale = layer.scale
        bmm2_scale = 1.0
    q_view = q_view.contiguous()

    # The TRTLLM-gen sparse-MLA kernel requires h_q in {64, 128}; zero-pad
    # the query heads to the allocated output head count. Padded heads
    # attend to the shared KV and are sliced off downstream (output is
    # padded too).
    target_heads = out_view.shape[1]
    local_heads = q_view.shape[1]
    if local_heads < target_heads:
        widened_q = q_view.new_zeros(
            (q_view.shape[0], target_heads, q_view.shape[2])
        )
        widened_q[:, :local_heads, :] = q_view
        q_view = widened_q

    workspace = _get_flashinfer_dsv4_workspace(q.device)
    cu_q = swa_metadata.query_start_loc
    cu_q_cpu = swa_metadata.query_start_loc_cpu
    assert cu_q is not None and cu_q_cpu is not None

    def launch(
        tokens: slice,
        reqs: slice,
        cum_seq_lens_q: torch.Tensor,
        q_lens_cpu: torch.Tensor,
    ) -> None:
        # One kernel call over a contiguous token / request range.
        flashinfer_trtllm_batch_decode_sparse_mla_dsv4(
            query=q_view[tokens],
            swa_kv_cache=swa_k_cache,
            workspace_buffer=workspace,
            sparse_indices=sparse_indices[tokens],
            compressed_kv_cache=compressed_kv_cache,
            sparse_topk_lens=sparse_topk_lens[tokens],
            seq_lens=seq_lens[reqs],
            out=out_view[tokens],
            bmm1_scale=bmm1_scale,
            bmm2_scale=bmm2_scale,
            sinks=layer.attn_sink,
            cum_seq_lens_q=cum_seq_lens_q,
            max_q_len=int(q_lens_cpu.max().item()),
        )

    # Keep Perkz's two-call decode/prefill split: the TRTLLM-gen launcher is
    # tuned for uniform-q batches, and collapsing the mixed batch into a
    # single call is the suspected source of the prior IMA.
    if n_decode_toks > 0:
        decode_cu = cu_q[: n_decode_reqs + 1]
        decode_cu_cpu = cu_q_cpu[: n_decode_reqs + 1]
        launch(
            tokens=slice(None, n_decode_toks),
            reqs=slice(None, n_decode_reqs),
            cum_seq_lens_q=decode_cu,
            q_lens_cpu=decode_cu_cpu[1:] - decode_cu_cpu[:-1],
        )

    if n_prefill_toks > 0:
        # The prefill query view re-anchors at offset 0, so the cumulative
        # query offsets are rebased to start at 0 as well.
        prefill_cu = cu_q[n_decode_reqs : n_reqs + 1] - cu_q[n_decode_reqs]
        prefill_cu_cpu = cu_q_cpu[n_decode_reqs : n_reqs + 1]
        launch(
            tokens=slice(n_decode_toks, n_toks),
            reqs=slice(n_decode_reqs, n_reqs),
            cum_seq_lens_q=prefill_cu,
            q_lens_cpu=prefill_cu_cpu[1:] - prefill_cu_cpu[:-1],
        )