Skip to content

vllm.models.deepseek_v4.nvidia.flashinfer_sparse

DeepSeek V4 FlashInfer TRTLLM-gen sparse MLA backend.

Uses FlashInfer's public trtllm_batch_decode_sparse_mla_dsv4 launcher with a contiguous bf16 / per-tensor FP8 KV cache. Shares the V4 sparse-index pipeline (SWA cache + compressor + indexer, 256-token blocks, head_size 512) with the FlashMLA V4 backend; only the attention forward differs.

DeepseekV4FlashInferMLASparseBackend

Bases: DeepseekV4FlashMLASparseBackend

Shares the FlashMLA V4 metadata/cache pipeline; swaps the attention impl.

Inheriting from the FlashMLA V4 backend reuses its FlashMLASparseMetadata builder (which the V4 sparse-index pipeline needs — the V3.2 FlashInfer builder lacks the c128a_* fields), 256-token blocks, head_size 512, and the contiguous (num_blocks, block_size, 512) cache shape for non-fp8_ds_mla dtypes.

Source code in vllm/models/deepseek_v4/nvidia/flashinfer_sparse.py
class DeepseekV4FlashInferMLASparseBackend(DeepseekV4FlashMLASparseBackend):
    """DeepSeek V4 sparse MLA backend whose attention runs through FlashInfer.

    Only the attention implementation is replaced. Everything else comes from
    the FlashMLA V4 backend this class derives from:

    * its ``FlashMLASparseMetadata`` builder, which the V4 sparse-index
      pipeline needs (the V3.2 FlashInfer builder lacks the ``c128a_*``
      fields);
    * 256-token blocks and head_size 512;
    * the contiguous (num_blocks, block_size, 512) cache shape used for
      non-``fp8_ds_mla`` dtypes.
    """

    @staticmethod
    def get_impl_cls() -> type["DeepseekV4FlashInferMLASparseImpl"]:
        """Return the attention implementation class paired with this backend."""
        return DeepseekV4FlashInferMLASparseImpl

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASHINFER_MLA_SPARSE_DSV4"

DeepseekV4FlashInferMLASparseImpl

Bases: DeepseekV4SparseMLAAttentionImpl

FlashInfer TRTLLM-gen sparse MLA implementation for DeepSeek V4.

Source code in vllm/models/deepseek_v4/nvidia/flashinfer_sparse.py
class DeepseekV4FlashInferMLASparseImpl(DeepseekV4SparseMLAAttentionImpl):
    """FlashInfer TRTLLM-gen sparse MLA implementation for DeepSeek V4."""

    backend_cls = DeepseekV4FlashInferMLASparseBackend

    @classmethod
    def get_padded_num_q_heads(cls, num_heads: int) -> int:
        # FP8 decode kernel only supports h_q = 64 or 128.
        if num_heads > 128:
            raise ValueError(
                f"DeepseekV4 Flashinfer MLA Sparse does not support {num_heads} heads "
                "(FP8 decode kernel requires h_q in {64, 128})."
            )
        return 64 if num_heads <= 64 else 128

    @classmethod
    def init_layer_buffers(cls, layer: "DeepseekV4MLAAttention") -> None:
        # Per-tensor FP8 scale buffers + precomputed scalar BMM scales. Only the
        # per-tensor FP8 cache path consumes these; bf16 reads ``layer.scale``.
        if layer.kv_cache_torch_dtype != torch.float8_e4m3fn:
            return
        # TODO: load real per-tensor Q/KV scales from the checkpoint; unit
        # scales until the scale tensor names are wired.
        fp8_q_scale = 1.0
        fp8_kv_scale = 1.0
        layer.register_buffer(
            "_flashinfer_fp8_q_scale",
            torch.tensor([fp8_q_scale], dtype=torch.float32),
            persistent=False,
        )
        layer.register_buffer(
            "_flashinfer_fp8_q_scale_inv",
            torch.tensor([1.0 / fp8_q_scale], dtype=torch.float32),
            persistent=False,
        )
        layer.register_buffer(
            "_flashinfer_fp8_kv_scale",
            torch.tensor([fp8_kv_scale], dtype=torch.float32),
            persistent=False,
        )
        # TRTLLM-gen takes scalar scale args on a distinct (correct) C++ path
        # vs 1-elem tensors, so these are Python floats. bmm1 folds the softmax
        # scale and the Q/KV per-tensor scales; bmm2 is the KV scale.
        layer._flashinfer_fp8_bmm1_scale = layer.scale * fp8_q_scale * fp8_kv_scale
        layer._flashinfer_fp8_bmm2_scale = fp8_kv_scale

    @classmethod
    def forward_mqa(  # type: ignore[override]
        cls,
        layer: "DeepseekV4MLAAttention",
        q: torch.Tensor,
        kv: torch.Tensor,
        positions: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        """Check the buffers, look up this layer's metadata, and run attention.

        The result is written into ``output`` in place. When the forward
        context carries no attention metadata (warmup dummy run) ``output`` is
        zero-filled and nothing else happens. ``kv`` and ``positions`` are
        accepted for interface compatibility and are not read here.

        Args:
            layer: Attention layer whose caches, prefix and compress ratio
                select the metadata and KV caches to use.
            q: Query tensor at the local head count.
            kv: Unused by this implementation.
            positions: Unused by this implementation.
            output: Destination buffer; its head dimension may be larger than
                ``q``'s because it is allocated at the padded head count.
        """
        # The TRTLLM-gen kernel requires h_q in {64, 128}: output is allocated
        # at the padded head count while q still has the local head count
        # (_forward pads q before calling the launcher). Token count and head
        # size must match exactly; the head count only has to be large enough.
        assert output.shape[0] == q.shape[0] and output.shape[-1] == q.shape[-1], (
            f"output buffer shape {output.shape} incompatible with q shape {q.shape}"
        )
        assert output.shape[1] >= q.shape[1], (
            f"output heads {output.shape[1]} must be >= q heads {q.shape[1]}"
        )

        # Per-tensor FP8 q produces a bf16 attention output; any other query
        # dtype is expected to be mirrored by the output buffer.
        if q.dtype == torch.float8_e4m3fn:
            expected_output_dtype = torch.bfloat16
        else:
            expected_output_dtype = q.dtype
        assert output.dtype == expected_output_dtype, (
            f"output dtype {output.dtype} must match expected {expected_output_dtype} "
            f"for q dtype {q.dtype}"
        )

        attn_metadata = get_forward_context().attn_metadata
        if attn_metadata is None:
            # Warmup dummy run: FlashInfer reads the cache directly and lazily
            # allocates its workspace, so nothing to reserve here.
            output.zero_()
            return

        # Metadata is keyed by layer prefix: one entry for this layer and one
        # for the SWA cache layer it is paired with.
        assert isinstance(attn_metadata, dict)
        own_metadata = cast(
            FlashMLASparseMetadata | None, attn_metadata.get(layer.prefix)
        )
        swa_layer = layer.swa_cache_layer
        swa_metadata = cast(
            "DeepseekSparseSWAMetadata | None",
            attn_metadata.get(swa_layer.prefix),
        )
        assert swa_metadata is not None

        swa_only = layer.compress_ratio <= 1
        cls._forward(
            layer=layer,
            q=q,
            # SWA-only layers don't allocate their own compressed KV cache.
            kv_cache=None if swa_only else layer.kv_cache,
            swa_k_cache=swa_layer.kv_cache,
            swa_metadata=swa_metadata,
            attn_metadata=own_metadata,
            swa_only=swa_only,
            output=output,
        )

    @classmethod
    def _build_sparse_index_metadata(
        cls,
        layer: "DeepseekV4MLAAttention",
        kv_cache: torch.Tensor | None,
        swa_k_cache: torch.Tensor,
        swa_metadata: "DeepseekSparseSWAMetadata",
        attn_metadata: FlashMLASparseMetadata | None,
        swa_only: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build the combined sparse-index tensors for the mixed batch.

        Picks the per-token index sources that match the layer kind (SWA-only,
        ``compress_ratio == 4``, or any other compress ratio) and hands them to
        ``build_flashinfer_mixed_sparse_indices``. Decode tokens come first in
        the batch, followed by prefill tokens.

        Args:
            layer: Attention layer; supplies ``window_size``,
                ``compress_ratio`` and ``topk_indices_buffer``.
            kv_cache: The layer's own compressed KV cache. Must not be ``None``
                unless ``swa_only`` is true.
            swa_k_cache: KV cache of the paired SWA cache layer.
            swa_metadata: SWA metadata carrying the batch split, sequence
                lengths, query offsets, block table and decode SWA indices.
            attn_metadata: FlashMLA sparse metadata for this layer. Must not be
                ``None`` unless ``swa_only`` is true.
            swa_only: Whether the layer attends to the SWA cache only.

        Returns:
            ``(compressed_kv_cache, seq_lens, sparse_indices,
            sparse_topk_lens)``. ``compressed_kv_cache`` is ``swa_k_cache`` for
            SWA-only layers and ``kv_cache`` otherwise; ``seq_lens`` is the
            int32 ``swa_metadata.seq_lens`` trimmed to the scheduled requests;
            the last two are the outputs of
            ``build_flashinfer_mixed_sparse_indices``.
        """
        num_decodes = swa_metadata.num_decodes
        num_prefills = swa_metadata.num_prefills
        num_decode_tokens = swa_metadata.num_decode_tokens
        num_prefill_tokens = swa_metadata.num_prefill_tokens
        num_reqs = num_decodes + num_prefills
        num_tokens = num_decode_tokens + num_prefill_tokens

        # Every optional SWA metadata field used below must be populated.
        assert swa_metadata.seq_lens is not None
        assert swa_metadata.query_start_loc is not None
        assert swa_metadata.token_to_req_indices is not None
        assert swa_metadata.decode_swa_indices is not None
        assert swa_metadata.block_table is not None

        # One row of ``window_size`` SWA indices per decode token.
        decode_swa_indices = swa_metadata.decode_swa_indices.reshape(
            num_decode_tokens, layer.window_size
        )
        # Defaults; only some of the branches below override them.
        decode_compressed_topk_lens = None
        decode_compressed_indices_are_local = False
        decode_is_valid_token = None

        if swa_only:
            # SWA-only layer: there is no compressed cache, so the SWA cache
            # stands in for it and the compressed top-k width is zero.
            assert layer.topk_indices_buffer is not None
            compressed_kv_cache = swa_k_cache
            decode_compressed_indices = None
            # Zero-width slice over the prefill token rows.
            prefill_topk_indices = layer.topk_indices_buffer[
                num_decode_tokens:num_tokens, :0
            ]
            compressed_block_table = None
            compressed_block_size = swa_metadata.block_size
            top_k = 0
        else:
            assert kv_cache is not None
            assert attn_metadata is not None
            compressed_kv_cache = kv_cache
            compressed_block_table = attn_metadata.block_table[:num_reqs]
            # Block size expressed in compressed entries rather than tokens.
            compressed_block_size = attn_metadata.block_size // layer.compress_ratio

            if layer.compress_ratio == 4:
                # Ratio-4 layers read their top-k indices from the layer's
                # ``topk_indices_buffer``: decode rows first, prefill rows
                # after them.
                assert layer.topk_indices_buffer is not None
                if num_prefill_tokens > 0:
                    prefill_topk_indices = layer.topk_indices_buffer[
                        num_decode_tokens:num_tokens
                    ]
                    top_k = prefill_topk_indices.shape[-1]
                else:
                    prefill_topk_indices = layer.topk_indices_buffer[:0, :0]
                    top_k = 0

                # Tell the builder the decode indices are local and pass the
                # per-token validity mask along with them.
                decode_compressed_indices_are_local = True
                assert swa_metadata.is_valid_token is not None
                decode_is_valid_token = swa_metadata.is_valid_token[:num_decode_tokens]
                if num_decode_tokens > 0:
                    decode_compressed_indices = layer.topk_indices_buffer[
                        :num_decode_tokens
                    ]
                else:
                    # Keep the logical width aligned with the mixed-batch case so
                    # pure-prefill steps reuse the same Triton specialization.
                    decode_compressed_indices = prefill_topk_indices[:0]
            else:
                # Any other compress ratio: indices come precomputed in the
                # ``c128a_*`` fields of ``attn_metadata``.
                # NOTE(review): presumably the ratio-128 layers, judging by the
                # field prefix — confirm against the metadata builder.
                if num_prefill_tokens > 0:
                    assert attn_metadata.c128a_prefill_topk_indices is not None
                    prefill_topk_indices = attn_metadata.c128a_prefill_topk_indices
                    top_k = prefill_topk_indices.shape[-1]
                else:
                    prefill_topk_indices = decode_swa_indices[:0, :0]
                    top_k = 0

                if num_decode_tokens > 0:
                    assert attn_metadata.c128a_global_decode_topk_indices is not None
                    assert attn_metadata.c128a_decode_topk_lens is not None
                    # Flatten to one row of indices per decode token.
                    decode_compressed_indices = (
                        attn_metadata.c128a_global_decode_topk_indices.view(
                            num_decode_tokens, -1
                        )
                    )
                    decode_compressed_topk_lens = attn_metadata.c128a_decode_topk_lens
                    if num_prefill_tokens == 0:
                        # Pure decode: re-derive the empty prefill placeholder
                        # from the decode index tensor instead.
                        prefill_topk_indices = decode_compressed_indices[:0, :0]
                else:
                    # Pure prefill: empty placeholders for the decode inputs.
                    decode_compressed_indices = prefill_topk_indices[:0]
                    decode_compressed_topk_lens = swa_metadata.seq_lens[:0]

        # Trim the per-request tensors to the scheduled requests.
        query_start_loc = swa_metadata.query_start_loc[: num_reqs + 1]
        seq_lens = swa_metadata.seq_lens[:num_reqs]
        assert seq_lens.dtype == torch.int32
        sparse_indices, sparse_topk_lens = build_flashinfer_mixed_sparse_indices(
            decode_swa_indices,
            decode_compressed_indices,
            decode_compressed_topk_lens,
            prefill_topk_indices[:num_prefill_tokens],
            query_start_loc,
            seq_lens,
            swa_metadata.token_to_req_indices[:num_tokens],
            swa_metadata.block_table[:num_reqs],
            swa_metadata.block_size,
            compressed_block_table,
            compressed_block_size,
            layer.window_size,
            layer.compress_ratio,
            top_k,
            decode_compressed_indices_are_local=decode_compressed_indices_are_local,
            decode_is_valid_token=decode_is_valid_token,
        )
        return compressed_kv_cache, seq_lens, sparse_indices, sparse_topk_lens

    @classmethod
    def _forward(
        cls,
        layer: "DeepseekV4MLAAttention",
        q: torch.Tensor,
        kv_cache: torch.Tensor | None,
        swa_k_cache: torch.Tensor,
        swa_metadata: "DeepseekSparseSWAMetadata",
        attn_metadata: FlashMLASparseMetadata | None,
        swa_only: bool,
        output: torch.Tensor,
    ) -> None:
        """Run sparse MLA attention for a mixed decode/prefill batch.

        Builds the sparse-index tensors, trims ``q``/``output`` to the
        scheduled tokens, zero-pads the query heads up to the output head
        count, and then calls the FlashInfer launcher once for the decode
        tokens and once for the prefill tokens. Results are written into
        ``output`` in place; nothing happens when no tokens are scheduled.

        Args:
            layer: Attention layer providing scales, sinks and cache dtype.
            q: Query tensor; bf16 for a bf16 cache, FP8 for an FP8 cache.
            kv_cache: The layer's compressed KV cache, or ``None`` for
                SWA-only layers.
            swa_k_cache: KV cache of the paired SWA cache layer.
            swa_metadata: SWA metadata with the batch split and query offsets.
            attn_metadata: FlashMLA sparse metadata for this layer, if any.
            swa_only: Whether the layer attends to the SWA cache only.
            output: Destination buffer, allocated at the padded head count.
        """
        assert layer.kv_cache_torch_dtype in (torch.bfloat16, torch.float8_e4m3fn)

        # Decode requests/tokens come first in the batch, prefill ones after.
        n_decode_reqs = swa_metadata.num_decodes
        n_decode_toks = swa_metadata.num_decode_tokens
        n_prefill_toks = swa_metadata.num_prefill_tokens
        n_reqs = n_decode_reqs + swa_metadata.num_prefills
        n_toks = n_decode_toks + n_prefill_toks
        if n_toks == 0:
            return

        index_metadata = cls._build_sparse_index_metadata(
            layer=layer,
            kv_cache=kv_cache,
            swa_k_cache=swa_k_cache,
            swa_metadata=swa_metadata,
            attn_metadata=attn_metadata,
            swa_only=swa_only,
        )
        compressed_kv_cache, seq_lens, sparse_indices, sparse_topk_lens = (
            index_metadata
        )

        # CUDA graph execution can pad q/output past the scheduled token count;
        # restrict to the real tokens (the launcher validates sparse indices).
        query = q[:n_toks]
        output = output[:n_toks]

        # bf16 defaults; the FP8 cache path swaps in the precomputed scales.
        bmm1_scale: float | torch.Tensor = layer.scale
        bmm2_scale: float | torch.Tensor = 1.0
        if layer.kv_cache_torch_dtype == torch.float8_e4m3fn:
            assert query.dtype == torch.float8_e4m3fn
            bmm1_scale = layer._flashinfer_fp8_bmm1_scale
            bmm2_scale = layer._flashinfer_fp8_bmm2_scale
        else:
            assert query.dtype == torch.bfloat16
            query = query.contiguous()

        # The TRTLLM-gen sparse-MLA kernel requires h_q in {64, 128}; zero-pad
        # the query heads to the allocated output head count. Padded heads attend
        # to the shared KV and are sliced off downstream (output is padded too).
        target_heads = output.shape[1]
        local_heads = query.shape[1]
        if local_heads < target_heads:
            widened_query = query.new_zeros(
                (query.shape[0], target_heads, query.shape[2])
            )
            widened_query[:, :local_heads, :] = query
            query = widened_query

        workspace = _get_flashinfer_dsv4_workspace(q.device)
        query_start_loc = swa_metadata.query_start_loc
        query_start_loc_cpu = swa_metadata.query_start_loc_cpu
        assert query_start_loc is not None and query_start_loc_cpu is not None

        def launch(
            token_span: slice,
            request_span: slice,
            cum_seq_lens_q: torch.Tensor,
            max_q_len: int,
        ) -> None:
            # One launcher call over a contiguous token/request range; the
            # caches, workspace, scales and sinks are shared by both calls.
            flashinfer_trtllm_batch_decode_sparse_mla_dsv4(
                query=query[token_span],
                swa_kv_cache=swa_k_cache,
                workspace_buffer=workspace,
                sparse_indices=sparse_indices[token_span],
                compressed_kv_cache=compressed_kv_cache,
                sparse_topk_lens=sparse_topk_lens[token_span],
                seq_lens=seq_lens[request_span],
                out=output[token_span],
                bmm1_scale=bmm1_scale,
                bmm2_scale=bmm2_scale,
                sinks=layer.attn_sink,
                cum_seq_lens_q=cum_seq_lens_q,
                max_q_len=max_q_len,
            )

        # Keep Perkz's two-call decode/prefill split: the TRTLLM-gen launcher is
        # tuned for uniform-q batches, and collapsing the mixed batch into a
        # single call is the suspected source of the prior IMA.
        if n_decode_toks > 0:
            decode_cu = query_start_loc[: n_decode_reqs + 1]
            decode_cu_cpu = query_start_loc_cpu[: n_decode_reqs + 1]
            # Per-request query lengths, taken from the CPU copy of the offsets.
            decode_q_lens_cpu = decode_cu_cpu[1:] - decode_cu_cpu[:-1]
            launch(
                slice(None, n_decode_toks),
                slice(None, n_decode_reqs),
                decode_cu,
                int(decode_q_lens_cpu.max().item()),
            )

        if n_prefill_toks > 0:
            # The prefill query view re-anchors at offset 0, so rebase the
            # cumulative query offsets to start at 0.
            prefill_cu = (
                query_start_loc[n_decode_reqs : n_reqs + 1]
                - query_start_loc[n_decode_reqs]
            )
            # Only differences are needed on the CPU side, so no rebasing here.
            prefill_cu_cpu = query_start_loc_cpu[n_decode_reqs : n_reqs + 1]
            prefill_q_lens_cpu = prefill_cu_cpu[1:] - prefill_cu_cpu[:-1]
            launch(
                slice(n_decode_toks, n_toks),
                slice(n_decode_reqs, n_reqs),
                prefill_cu,
                int(prefill_q_lens_cpu.max().item()),
            )

_build_sparse_index_metadata classmethod

_build_sparse_index_metadata(
    layer: DeepseekV4MLAAttention,
    kv_cache: Tensor | None,
    swa_k_cache: Tensor,
    swa_metadata: DeepseekSparseSWAMetadata,
    attn_metadata: FlashMLASparseMetadata | None,
    swa_only: bool,
) -> tuple[Tensor, Tensor, Tensor, Tensor]

Build the combined sparse-index tensors for the mixed batch.

Returns (compressed_kv_cache, seq_lens, sparse_indices, sparse_topk_lens).

Source code in vllm/models/deepseek_v4/nvidia/flashinfer_sparse.py
@classmethod
def _build_sparse_index_metadata(
    cls,
    layer: "DeepseekV4MLAAttention",
    kv_cache: torch.Tensor | None,
    swa_k_cache: torch.Tensor,
    swa_metadata: "DeepseekSparseSWAMetadata",
    attn_metadata: FlashMLASparseMetadata | None,
    swa_only: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the combined sparse-index tensors for the mixed batch.

    Picks the per-token index sources that match the layer kind (SWA-only,
    ``compress_ratio == 4``, or any other compress ratio) and hands them to
    ``build_flashinfer_mixed_sparse_indices``. Decode tokens come first in the
    batch, followed by prefill tokens.

    Args:
        layer: Attention layer; supplies ``window_size``, ``compress_ratio``
            and ``topk_indices_buffer``.
        kv_cache: The layer's own compressed KV cache. Must not be ``None``
            unless ``swa_only`` is true.
        swa_k_cache: KV cache of the paired SWA cache layer.
        swa_metadata: SWA metadata carrying the batch split, sequence lengths,
            query offsets, block table and decode SWA indices.
        attn_metadata: FlashMLA sparse metadata for this layer. Must not be
            ``None`` unless ``swa_only`` is true.
        swa_only: Whether the layer attends to the SWA cache only.

    Returns:
        ``(compressed_kv_cache, seq_lens, sparse_indices,
        sparse_topk_lens)``. ``compressed_kv_cache`` is ``swa_k_cache`` for
        SWA-only layers and ``kv_cache`` otherwise; ``seq_lens`` is the int32
        ``swa_metadata.seq_lens`` trimmed to the scheduled requests; the last
        two are the outputs of ``build_flashinfer_mixed_sparse_indices``.
    """
    num_decodes = swa_metadata.num_decodes
    num_prefills = swa_metadata.num_prefills
    num_decode_tokens = swa_metadata.num_decode_tokens
    num_prefill_tokens = swa_metadata.num_prefill_tokens
    num_reqs = num_decodes + num_prefills
    num_tokens = num_decode_tokens + num_prefill_tokens

    # Every optional SWA metadata field used below must be populated.
    assert swa_metadata.seq_lens is not None
    assert swa_metadata.query_start_loc is not None
    assert swa_metadata.token_to_req_indices is not None
    assert swa_metadata.decode_swa_indices is not None
    assert swa_metadata.block_table is not None

    # One row of ``window_size`` SWA indices per decode token.
    decode_swa_indices = swa_metadata.decode_swa_indices.reshape(
        num_decode_tokens, layer.window_size
    )
    # Defaults; only some of the branches below override them.
    decode_compressed_topk_lens = None
    decode_compressed_indices_are_local = False
    decode_is_valid_token = None

    if swa_only:
        # SWA-only layer: there is no compressed cache, so the SWA cache stands
        # in for it and the compressed top-k width is zero.
        assert layer.topk_indices_buffer is not None
        compressed_kv_cache = swa_k_cache
        decode_compressed_indices = None
        # Zero-width slice over the prefill token rows.
        prefill_topk_indices = layer.topk_indices_buffer[
            num_decode_tokens:num_tokens, :0
        ]
        compressed_block_table = None
        compressed_block_size = swa_metadata.block_size
        top_k = 0
    else:
        assert kv_cache is not None
        assert attn_metadata is not None
        compressed_kv_cache = kv_cache
        compressed_block_table = attn_metadata.block_table[:num_reqs]
        # Block size expressed in compressed entries rather than tokens.
        compressed_block_size = attn_metadata.block_size // layer.compress_ratio

        if layer.compress_ratio == 4:
            # Ratio-4 layers read their top-k indices from the layer's
            # ``topk_indices_buffer``: decode rows first, prefill rows after.
            assert layer.topk_indices_buffer is not None
            if num_prefill_tokens > 0:
                prefill_topk_indices = layer.topk_indices_buffer[
                    num_decode_tokens:num_tokens
                ]
                top_k = prefill_topk_indices.shape[-1]
            else:
                prefill_topk_indices = layer.topk_indices_buffer[:0, :0]
                top_k = 0

            # Tell the builder the decode indices are local and pass the
            # per-token validity mask along with them.
            decode_compressed_indices_are_local = True
            assert swa_metadata.is_valid_token is not None
            decode_is_valid_token = swa_metadata.is_valid_token[:num_decode_tokens]
            if num_decode_tokens > 0:
                decode_compressed_indices = layer.topk_indices_buffer[
                    :num_decode_tokens
                ]
            else:
                # Keep the logical width aligned with the mixed-batch case so
                # pure-prefill steps reuse the same Triton specialization.
                decode_compressed_indices = prefill_topk_indices[:0]
        else:
            # Any other compress ratio: indices come precomputed in the
            # ``c128a_*`` fields of ``attn_metadata``.
            # NOTE(review): presumably the ratio-128 layers, judging by the
            # field prefix — confirm against the metadata builder.
            if num_prefill_tokens > 0:
                assert attn_metadata.c128a_prefill_topk_indices is not None
                prefill_topk_indices = attn_metadata.c128a_prefill_topk_indices
                top_k = prefill_topk_indices.shape[-1]
            else:
                prefill_topk_indices = decode_swa_indices[:0, :0]
                top_k = 0

            if num_decode_tokens > 0:
                assert attn_metadata.c128a_global_decode_topk_indices is not None
                assert attn_metadata.c128a_decode_topk_lens is not None
                # Flatten to one row of indices per decode token.
                decode_compressed_indices = (
                    attn_metadata.c128a_global_decode_topk_indices.view(
                        num_decode_tokens, -1
                    )
                )
                decode_compressed_topk_lens = attn_metadata.c128a_decode_topk_lens
                if num_prefill_tokens == 0:
                    # Pure decode: re-derive the empty prefill placeholder from
                    # the decode index tensor instead.
                    prefill_topk_indices = decode_compressed_indices[:0, :0]
            else:
                # Pure prefill: empty placeholders for the decode inputs.
                decode_compressed_indices = prefill_topk_indices[:0]
                decode_compressed_topk_lens = swa_metadata.seq_lens[:0]

    # Trim the per-request tensors to the scheduled requests.
    query_start_loc = swa_metadata.query_start_loc[: num_reqs + 1]
    seq_lens = swa_metadata.seq_lens[:num_reqs]
    assert seq_lens.dtype == torch.int32
    sparse_indices, sparse_topk_lens = build_flashinfer_mixed_sparse_indices(
        decode_swa_indices,
        decode_compressed_indices,
        decode_compressed_topk_lens,
        prefill_topk_indices[:num_prefill_tokens],
        query_start_loc,
        seq_lens,
        swa_metadata.token_to_req_indices[:num_tokens],
        swa_metadata.block_table[:num_reqs],
        swa_metadata.block_size,
        compressed_block_table,
        compressed_block_size,
        layer.window_size,
        layer.compress_ratio,
        top_k,
        decode_compressed_indices_are_local=decode_compressed_indices_are_local,
        decode_is_valid_token=decode_is_valid_token,
    )
    return compressed_kv_cache, seq_lens, sparse_indices, sparse_topk_lens