Skip to content

vllm.v1.worker.gpu.model_states.mamba_hybrid

Classes:

MambaHybridModelState

Bases: DefaultModelState

Model state for hybrid attention + Mamba / linear-attention models.

Methods:

  • preprocess_state

    Migrate each request's mamba state across block boundaries before the forward pass.

Source code in vllm/v1/worker/gpu/model_states/mamba_hybrid.py
class MambaHybridModelState(DefaultModelState):
    """Model state for hybrid attention + Mamba / linear-attention models.

    Keeps a per-request-slot ``num_accepted_tokens`` GPU buffer that is
    gathered into the mamba attention metadata in ``prepare_attn`` and updated
    after each step in ``postprocess_state``. When
    ``cache_config.mamba_cache_mode == "align"``, it additionally keeps
    GPU-resident bookkeeping used to migrate each request's mamba state across
    block boundaries before the forward (``preprocess_state``) and to save the
    running state to a block-aligned position after spec-decode steps
    (``postprocess_state``).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ) -> None:
        """Allocate per-slot GPU buffers (align-mode buffers only when needed)."""
        super().__init__(vllm_config, model, encoder_cache, device)
        self.cache_config = vllm_config.cache_config
        # Indexed by request slot (req_index / idx_mapping values), not by
        # batch position. 1 is the neutral non-spec-decode value.
        self.num_accepted_tokens_gpu = torch.ones(
            self.max_num_reqs, dtype=torch.int32, device=self.device
        )
        # Pre-copy "align" prefix-cache state (V2). The migration of each
        # request's mamba state across block boundaries runs as a fused GPU
        # kernel reusing the postprocess copy machinery, so the per-step src
        # columns and the running state_idx are kept GPU-resident.
        self._align_mode = self.cache_config.mamba_cache_mode == "align"
        if self._align_mode:
            # Running mamba state block index per request slot.
            self._mamba_state_idx_gpu = torch.zeros(
                self.max_num_reqs, dtype=torch.int32, device=self.device
            )
            # Per-step copy source column written by the preprocess kernel;
            # initialized to -1, which the pre-copy kernel treats as "no copy".
            self._mamba_src_col_gpu = torch.full(
                (self.max_num_reqs,), -1, dtype=torch.int32, device=self.device
            )
            # Per-step source offset written by the preprocess kernel
            # (presumably a token offset inside the source block — TODO confirm).
            self._mamba_src_off_gpu = torch.zeros(
                self.max_num_reqs, dtype=torch.int32, device=self.device
            )
            # Created lazily on the first real batch (see _ensure_align_ctx).
            self._mamba_ctx: MambaSpecDecodeGPUContext | None = None
            # Cached by _get_mamba_group_info on first use.
            self._mamba_group_ids: list[int] = []
            self._mamba_spec: MambaSpec | None = None

    def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
        """Reset per-slot mamba bookkeeping for a request placed at ``req_index``.

        The slot may previously have belonged to another request, so
        ``num_accepted_tokens`` is always reset to the neutral value 1. In
        align mode the running state block index is also seeded from the
        request's ``num_computed_tokens``. With 0 computed tokens this is
        ``(0 - 1) // block_size == -1``.
        """
        super().add_request(req_index, new_req_data)
        if self._align_mode:
            # Seed the running state block from the resumed/prefilled position.
            # NOTE(review): this uses cache_config.block_size while
            # preprocess_state passes mamba_spec.block_size to the kernel;
            # confirm the two are always equal in align mode.
            self._mamba_state_idx_gpu[req_index] = (
                new_req_data.num_computed_tokens - 1
            ) // self.cache_config.block_size
        # Reset unconditionally: otherwise a reused slot would carry over the
        # previous occupant's accepted-token count into prepare_attn.
        self.num_accepted_tokens_gpu[req_index] = 1

    def _get_mamba_group_info(
        self, kv_cache_config: KVCacheConfig
    ) -> tuple[list[int], MambaSpec]:
        """Return the KV-cache group ids holding mamba layers and their shared spec.

        The result is computed once from ``kv_cache_config`` and cached.

        Raises:
            AssertionError: if no group uses a ``MambaSpec`` or the mamba groups
                do not all share an identical spec.
        """
        if self._mamba_spec is None:
            group_ids: list[int] = []
            specs: list[MambaSpec] = []
            for i, group in enumerate(kv_cache_config.kv_cache_groups):
                spec = group.kv_cache_spec
                if isinstance(spec, MambaSpec):
                    group_ids.append(i)
                    specs.append(spec)
            assert specs, "no mamba layers in the model"
            assert all(specs[0] == s for s in specs), (
                "all mamba KV-cache groups must share the same MambaSpec"
            )
            self._mamba_group_ids = group_ids
            self._mamba_spec = specs[0]
        return self._mamba_group_ids, self._mamba_spec

    def _ensure_align_ctx(
        self,
        kv_cache_config: KVCacheConfig,
        mamba_group_ids: list[int],
        block_tables: tuple[torch.Tensor, ...],
    ) -> MambaSpecDecodeGPUContext:
        """Return the align-mode fused-copy context, creating/initializing it lazily.

        On first call the context is created with one state type per model
        copy function. Whenever it is not yet initialized, it is initialized
        from the static forward context and the mamba groups' block tables.

        Raises:
            AssertionError: if the model uses the conv-state copy spec with
                the dim-first (DS) conv layout while speculative decoding is
                configured.
        """
        if self._mamba_ctx is None:
            copy_funcs = self.model.get_mamba_state_copy_func()
            # The fused copy kernels shift conv windows assuming the SD layout;
            # the DS layout cannot express a >0 spec-decode shift as a single
            # contiguous copy (mirrors get_conv_copy_spec's NotImplementedError).
            if get_conv_copy_spec in copy_funcs and is_conv_state_dim_first():
                assert self.vllm_config.speculative_config is None, (
                    "DS conv state layout does not support mamba align state "
                    "copies with speculative decoding"
                )
            self._mamba_ctx = MambaSpecDecodeGPUContext.create(
                max_num_reqs=self.max_num_reqs,
                kv_cache_config=kv_cache_config,
                num_state_types=len(copy_funcs),
                device=self.device,
                make_buffer=lambda n, dtype: CpuGpuBuffer(
                    n, dtype=dtype, device=self.device
                ),
            )
        ctx = self._mamba_ctx
        if not ctx.is_initialized:
            forward_context = self.vllm_config.compilation_config.static_forward_context
            # block_tables are batch-order slices of the persistent
            # input_block_tables (stable data_ptr), so the metadata is captured
            # once here and reused across steps.
            ctx.initialize_from_forward_context(
                kv_cache_config,
                forward_context,
                self.model.get_mamba_state_copy_func(),
                [block_tables[gid] for gid in mamba_group_ids],
            )
        return ctx

    def preprocess_state(
        self,
        input_batch: InputBatch,
        block_tables: tuple[torch.Tensor, ...],
        kv_cache_config: KVCacheConfig,
        num_computed_tokens: torch.Tensor,
    ) -> None:
        """Migrate each request's mamba state across block boundaries before the
        forward (V1 align semantics, done on GPU). Runs on real batches only
        (dummy DP/profiling runs skip preprocess_state), and before
        ``prepare_attn`` gathers ``num_accepted_tokens``, so the boundary reset
        is visible to the forward kernels.
        """
        if not self._align_mode:
            return
        num_reqs = input_batch.num_reqs
        if num_reqs == 0:
            return
        mamba_group_ids, mamba_spec = self._get_mamba_group_info(kv_cache_config)
        ctx = self._ensure_align_ctx(kv_cache_config, mamba_group_ids, block_tables)

        # The state-advance + pre-copy kernels run every step; they fast-exit per
        # request when src_col < 0 or src_col == dst_col, so no copy happens on
        # steps that don't cross a block boundary. (Skipping the launch entirely
        # would need a V1-style async-D2H of the actual num_computed, since
        # num_computed_tokens_np is an optimistic mirror under async scheduling;
        # the launch cost is ~0.3% of TPOT, so the GPU fast-exit suffices.)
        block = 256
        grid = (triton.cdiv(num_reqs, block),)
        preprocess_mamba_align_fused_kernel[grid](
            input_batch.idx_mapping,
            self._mamba_state_idx_gpu,
            num_computed_tokens,
            input_batch.query_start_loc,
            self.num_accepted_tokens_gpu,
            self._mamba_src_col_gpu,
            self._mamba_src_off_gpu,
            num_reqs,
            BLOCK_SIZE=block,
            MAMBA_BLOCK_SIZE=mamba_spec.block_size,
        )
        ctx.run_fused_precopy(
            num_reqs,
            self._mamba_state_idx_gpu,
            self._mamba_src_col_gpu,
            self._mamba_src_off_gpu,
            input_batch.idx_mapping,
        )

    def prepare_attn(
        self,
        input_batch: InputBatch,
        cudagraph_mode: CUDAGraphMode,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        for_capture: bool = False,
    ) -> dict[str, Any]:
        """Build attention metadata (including mamba-specific fields) for a batch.

        Under ``CUDAGraphMode.FULL`` the request/token counts include padding.
        Padded request rows get ``is_prefilling=False``. On non-capture runs
        they also get ``num_accepted_tokens=1`` and a draft count of -1.

        Args:
            for_capture: True during CUDA graph capture. ``max_seq_len`` is then
                set to ``max_model_len``, and ``num_accepted_tokens`` /
                ``num_decode_draft_tokens_cpu`` are left as None.

        Returns:
            The metadata dict produced by ``build_attn_metadata``.
        """
        if cudagraph_mode == CUDAGraphMode.FULL:
            num_reqs = input_batch.num_reqs_after_padding
            num_tokens = input_batch.num_tokens_after_padding
        else:
            num_reqs = input_batch.num_reqs
            num_tokens = input_batch.num_tokens
        query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
        max_query_len = input_batch.num_scheduled_tokens.max().item()
        seq_lens_cpu_upper_bound = input_batch.seq_lens_cpu_upper_bound
        if for_capture:
            # Capture with worst-case max_seq_len so the graph is valid at any replay.
            max_seq_len = self.max_model_len
        else:
            max_seq_len = seq_lens_cpu_upper_bound[:num_reqs].max().item()

        # Padded rows (beyond input_batch.num_reqs) stay False.
        is_prefilling = torch.zeros(num_reqs, dtype=torch.bool, device="cpu")
        is_prefilling[: input_batch.num_reqs] = torch.from_numpy(
            input_batch.is_prefilling_np
        )
        # During CUDAGraph capture, num_decode_draft_tokens_cpu and num_accepted_tokens
        # are created by attn_metadata_builder.build_for_cudagraph_capture, so we only
        # compute them during actual (non-capture) forward execution.
        num_accepted_tokens = None
        num_decode_draft_tokens_cpu = None
        if not for_capture:
            # Gather per-slot counts into batch order; padded rows keep 1.
            num_accepted_tokens = self.num_accepted_tokens_gpu.new_ones(num_reqs)
            num_accepted_tokens[: input_batch.num_reqs] = self.num_accepted_tokens_gpu[
                input_batch.idx_mapping
            ]

            # GDN uses >= 0 to select spec-decode rows, so non-decode rows
            # need the -1 sentinel rather than a raw zero draft count.
            num_decode_draft_tokens_np = np.full(num_reqs, -1, dtype=np.int32)
            if input_batch.num_draft_tokens_per_req is not None:
                has_draft_tokens = input_batch.num_draft_tokens_per_req > 0
                spec_decode_mask = has_draft_tokens & ~input_batch.is_prefilling_np
                num_decode_draft_tokens_np[: input_batch.num_reqs] = np.where(
                    spec_decode_mask, input_batch.num_draft_tokens_per_req, -1
                )
            num_decode_draft_tokens_cpu = torch.from_numpy(num_decode_draft_tokens_np)

        mamba_attn_metadata = MambaHybridAttnMetadata(
            is_prefilling=is_prefilling,
            num_accepted_tokens=num_accepted_tokens,
            num_decode_draft_tokens_cpu=num_decode_draft_tokens_cpu,
        )
        return build_attn_metadata(
            attn_groups=attn_groups,
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            max_query_len=max_query_len,
            seq_lens=input_batch.seq_lens,
            max_seq_len=max_seq_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=kv_cache_config,
            seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
            dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
            model_specific_attn_metadata=mamba_attn_metadata,
            for_cudagraph_capture=for_capture,
            rswa_prefix_lens=input_batch.prompt_lens,
        )

    def postprocess_state(
        self,
        idx_mapping: torch.Tensor,
        num_sampled: torch.Tensor | int,
        num_computed_tokens: torch.Tensor | None = None,
    ) -> None:
        """Record accepted-token counts after a step and, in align mode, save state.

        Args:
            idx_mapping: Request-slot index for each batch row. Under PP it may
                contain -1 sentinels (filtered rows), which the tensor path
                skips.
            num_sampled: Per-row sampled-token counts (tensor), or one count
                applied to every row (int; clamped to at least 1).
            num_computed_tokens: Post-step computed-token counts. The align-mode
                state save is skipped when this is None or before the align
                context exists.
        """
        num_reqs = idx_mapping.shape[0]
        # Chunked prefill does not sample a token, so num_sampled can be 0.
        # Mamba treats num_accepted_tokens=1 as the neutral non-spec value.
        if not isinstance(num_sampled, int):
            # idx_mapping may contain -1 sentinels (filtered rows) under PP; the
            # kernel skips them rather than scattering with a host-side gather.
            if num_reqs:
                _scatter_num_accepted_kernel[(num_reqs,)](
                    idx_mapping, num_sampled, self.num_accepted_tokens_gpu
                )
        else:
            # Fill with single value.
            # NOTE(review): unlike the tensor path, this does not explicitly
            # filter -1 sentinels in idx_mapping — confirm callers never pass
            # sentinels together with an int num_sampled.
            self.num_accepted_tokens_gpu.index_fill_(
                0, idx_mapping, max(num_sampled, 1)
            )

        # Align: save the running state to the block-aligned position when
        # spec-decode acceptance leaves the sequence non-block-aligned (mirrors
        # the V1 align postprocess). num_computed_tokens already holds the
        # post-step advanced count.
        if (
            self._align_mode
            and num_computed_tokens is not None
            and self._mamba_ctx is not None
            and num_reqs
        ):
            self._mamba_ctx.run_fused_postprocess_align(
                num_reqs,
                self.num_accepted_tokens_gpu,
                self._mamba_state_idx_gpu,
                num_computed_tokens,
                idx_mapping,
            )

preprocess_state(input_batch, block_tables, kv_cache_config, num_computed_tokens)

Migrate each request's mamba state across block boundaries before the forward (V1 align semantics, done on GPU). Runs on real batches only (dummy DP/profiling runs skip preprocess_state), and before prepare_attn gathers num_accepted_tokens, so the boundary reset is visible to the forward kernels.

Source code in vllm/v1/worker/gpu/model_states/mamba_hybrid.py
def preprocess_state(
    self,
    input_batch: InputBatch,
    block_tables: tuple[torch.Tensor, ...],
    kv_cache_config: KVCacheConfig,
    num_computed_tokens: torch.Tensor,
) -> None:
    """Carry each request's mamba state over block boundaries ahead of the
    forward pass (V1 "align" semantics, executed on the GPU).

    Only real batches reach this (dummy DP / profiling runs never call
    preprocess_state). It runs before ``prepare_attn`` gathers
    ``num_accepted_tokens``, so any boundary reset performed here is already
    visible to the forward kernels. It is a no-op outside align mode or for an
    empty batch.
    """
    # Guard clauses: only align-mode prefix caching with a non-empty batch
    # has anything to migrate.
    if not self._align_mode:
        return
    n_reqs = input_batch.num_reqs
    if not n_reqs:
        return

    group_ids, spec = self._get_mamba_group_info(kv_cache_config)
    align_ctx = self._ensure_align_ctx(kv_cache_config, group_ids, block_tables)

    # Both kernels are launched unconditionally every step. Per request they
    # bail out early when src_col < 0 or src_col == dst_col, so steps that do
    # not cross a block boundary copy nothing. Avoiding the launch altogether
    # would require a V1-style async D2H of the true num_computed (under async
    # scheduling num_computed_tokens_np is only an optimistic mirror); since
    # the launch costs ~0.3% of TPOT, the on-GPU early exit is good enough.
    reqs_per_program = 256
    launch_grid = (triton.cdiv(n_reqs, reqs_per_program),)
    idx_mapping = input_batch.idx_mapping
    preprocess_mamba_align_fused_kernel[launch_grid](
        idx_mapping,
        self._mamba_state_idx_gpu,
        num_computed_tokens,
        input_batch.query_start_loc,
        self.num_accepted_tokens_gpu,
        self._mamba_src_col_gpu,
        self._mamba_src_off_gpu,
        n_reqs,
        BLOCK_SIZE=reqs_per_program,
        MAMBA_BLOCK_SIZE=spec.block_size,
    )
    # Perform the actual state copies for the src columns selected above.
    align_ctx.run_fused_precopy(
        n_reqs,
        self._mamba_state_idx_gpu,
        self._mamba_src_col_gpu,
        self._mamba_src_off_gpu,
        idx_mapping,
    )