Skip to content

vllm.v1.worker.mamba_utils

Classes:

Functions:

MambaBuffers dataclass

Single owner for all mamba-specific runner buffers.

The two sub-objects have different gates: preprocess is needed whenever mamba_cache_mode == "align"; postprocess_align is needed only when align is combined with speculative decoding on a hybrid model, and is None otherwise.

Source code in vllm/v1/worker/mamba_utils.py
@dataclasses.dataclass
class MambaBuffers:
    """Own every mamba-specific buffer used by the model runner.

    ``preprocess`` is always allocated; it is required whenever
    ``mamba_cache_mode == "align"``. ``postprocess_align`` is allocated only
    when ``with_postprocess_align`` is requested (align mode combined with
    speculative decoding on a hybrid model) and is ``None`` otherwise.
    """

    preprocess: MambaCopyBuffers
    postprocess_align: MambaSpecDecodeGPUContext | None

    @classmethod
    def create(
        cls,
        max_num_reqs: int,
        kv_cache_config: KVCacheConfig,
        copy_funcs: tuple[MambaStateCopyFunc, ...],
        make_buffer: Callable[..., CpuGpuBuffer],
        device: torch.device,
        with_postprocess_align: bool,
    ) -> "MambaBuffers":
        """Allocate the preprocess buffers and, if requested, the align
        postprocess context (one metadata slot per entry of ``copy_funcs``)."""
        # Built before the optional context so ``make_buffer`` is always
        # invoked in the same order.
        copy_buffers = MambaCopyBuffers.create(
            max_num_reqs, kv_cache_config, copy_funcs, make_buffer
        )

        align_ctx: MambaSpecDecodeGPUContext | None = None
        if with_postprocess_align:
            align_ctx = MambaSpecDecodeGPUContext.create(
                max_num_reqs=max_num_reqs,
                kv_cache_config=kv_cache_config,
                num_state_types=len(copy_funcs),
                device=device,
                make_buffer=make_buffer,
            )

        return cls(preprocess=copy_buffers, postprocess_align=align_ctx)

MambaSpecDecodeGPUContext dataclass

Context for GPU-side Mamba state copy operations during the fused postprocess path.

Only used when speculative decoding is enabled on a hybrid model (and mamba_cache_mode == "align").

Precomputes memory layout metadata (base addresses, strides, element sizes) so the GPU kernel can perform state copies without CPU-GPU sync.

State types are distinguished by conv_width: >0 for conv states (sliding window with offset-based copies), 0 for temporal states (full block copies).

Methods:

Source code in vllm/v1/worker/mamba_utils.py
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
@dataclasses.dataclass
class MambaSpecDecodeGPUContext:
    """
    Context for GPU-side Mamba state copy operations during the
    fused postprocess path.

    Only used when speculative decoding is enabled on a hybrid model
    (and ``mamba_cache_mode == "align"``).

    Precomputes memory layout metadata (base addresses, strides, element sizes)
    so the GPU kernel can perform state copies without CPU-GPU sync.

    State types are distinguished by conv_width: >0 for conv states (sliding
    window with offset-based copies), 0 for temporal states (full block copies).
    """

    # Per-state metadata tensors (shape: [num_layers * num_state_types]),
    # zero-filled by ``create`` and populated by
    # ``initialize_from_forward_context``.
    state_base_addrs: torch.Tensor  # int64: base address of each state tensor
    state_block_strides: torch.Tensor  # int64: bytes per block
    state_elem_sizes: torch.Tensor  # int32: element size in bytes
    # int64: conv (SD) -> elements per conv position; conv (DS) -> 1;
    # temporal -> natural elements per block (used for the copy size).
    state_inner_sizes: torch.Tensor
    state_conv_widths: torch.Tensor  # int32: conv width (0 for temporal states)
    state_group_indices: torch.Tensor  # int32: maps state_idx to group index
    # DS conv row metadata. Zero keeps the single-region copy path.
    state_dim_row_count: torch.Tensor  # int32: per-block dim row count
    state_dim_row_stride: torch.Tensor  # int64: bytes between rows

    # Configuration
    block_size: int
    num_layers: int
    num_state_types: int
    mamba_group_ids: list[int]
    num_groups: int

    # Output buffer for num_accepted_tokens updates
    num_accepted_tokens_out: torch.Tensor

    # Per-group block-table base addresses: int64[num_groups]. Populated in
    # initialize_from_forward_context from the persistent per-group block
    # table tensors (whose data_ptr is stable across steps).
    block_table_ptrs: torch.Tensor
    block_table_stride_req: int = 0

    # Per-request staging buffers (CPU+GPU mirrors). The runner stages
    # values into the CPU view in ``_prepare_inputs`` and the fused kernel
    # reads the GPU side. These only exist when the postprocess kernel is
    # enabled (spec decode + hybrid + align mode).
    mamba_state_idx_buf: CpuGpuBuffer | None = None
    num_scheduled_tokens_buf: CpuGpuBuffer | None = None
    num_computed_tokens_buf: CpuGpuBuffer | None = None
    num_draft_tokens_buf: CpuGpuBuffer | None = None

    # Flag to track if metadata has been populated
    is_initialized: bool = False

    @classmethod
    def create(
        cls,
        max_num_reqs: int,
        kv_cache_config: KVCacheConfig,
        num_state_types: int,
        device: torch.device,
        make_buffer: Callable[..., CpuGpuBuffer],
    ) -> "MambaSpecDecodeGPUContext":
        """Allocate zeroed metadata tensors and per-request staging buffers.

        The per-state metadata is filled in later by
        ``initialize_from_forward_context``; until then ``is_initialized``
        stays ``False`` and the ``run_*`` methods return without launching.
        """
        group_ids, spec = get_mamba_groups(kv_cache_config)
        groups = kv_cache_config.kv_cache_groups

        # One metadata slot per (layer, state type) across all mamba groups.
        layer_count = sum(len(groups[gid].layer_names) for gid in group_ids)
        state_count = layer_count * num_state_types

        def per_state(dtype: torch.dtype) -> torch.Tensor:
            return torch.zeros(state_count, dtype=dtype, device=device)

        def per_request():
            return make_buffer(max_num_reqs, dtype=torch.int32)

        return cls(
            state_base_addrs=per_state(torch.int64),
            state_block_strides=per_state(torch.int64),
            state_elem_sizes=per_state(torch.int32),
            state_inner_sizes=per_state(torch.int64),
            state_conv_widths=per_state(torch.int32),
            state_group_indices=per_state(torch.int32),
            state_dim_row_count=per_state(torch.int32),
            state_dim_row_stride=per_state(torch.int64),
            block_size=spec.block_size,
            num_layers=layer_count,
            num_state_types=num_state_types,
            mamba_group_ids=group_ids,
            num_groups=len(group_ids),
            num_accepted_tokens_out=torch.zeros(
                max_num_reqs, dtype=torch.int32, device=device
            ),
            block_table_ptrs=torch.zeros(
                len(group_ids), dtype=torch.int64, device=device
            ),
            mamba_state_idx_buf=per_request(),
            num_scheduled_tokens_buf=per_request(),
            num_computed_tokens_buf=per_request(),
            num_draft_tokens_buf=per_request(),
            is_initialized=False,
        )

    def initialize_from_forward_context(
        self,
        kv_cache_config: KVCacheConfig,
        forward_context: dict[str, Any],
        mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
        block_tables: list[torch.Tensor],
    ) -> None:
        """
        Extract and cache memory layout metadata from Mamba state tensors.

        This method populates the pre-allocated metadata tensors with the
        information the fused copy kernels (``postprocess_mamba_fused_kernel``
        and ``precopy_mamba_align_fused_kernel``) need to perform state copies
        entirely on the GPU without CPU-GPU synchronization.

        Slots are filled in group -> layer -> state-type order. For each slot:
        - state_base_addrs: ``data_ptr()`` of the state tensor
        - state_block_strides: bytes between consecutive blocks
          (``stride(0) * element_size()``; ``numel() * element_size()`` for
          1-D tensors)
        - state_elem_sizes: element size in bytes (e.g., 2 for float16)
        - state_conv_widths / state_inner_sizes:
          * conv state, dim-first (DS) layout: width = ``size(2)``,
            inner size = 1; additionally state_dim_row_count = ``size(1)``
            and state_dim_row_stride = ``stride(1) * element_size()``
          * conv state, SD layout: width = ``size(1)``, inner size =
            ``stride(1)`` (elements per conv position)
          * temporal state: width = 0, inner size = natural elements per
            block (``state[0].numel()``, or 1 for 1-D tensors), which keeps
            the copy size correct even when the block stride is padded
        - state_group_indices: local index of the owning mamba group

        Conv vs temporal is decided by the identity of the copy function:
        ``get_conv_copy_spec`` marks a conv state, ``get_temporal_copy_spec``
        a temporal state.

        This method runs at most once (guarded by the ``is_initialized``
        flag) since the metadata is static after model loading.

        Args:
            kv_cache_config: Configuration containing KV cache group info and
                layer name mappings.
            forward_context: Dictionary mapping layer names to attention objects,
                populated after the model is loaded. Each attention object must
                have a `kv_cache` attribute containing the list of state tensors.
            mamba_state_copy_funcs: Tuple of copy functions (one per state type)
                used to determine whether each state is a conv or temporal state.
            block_tables: per-mamba-group persistent block-table tensors, in
                the same order as `mamba_group_ids`. Their `data_ptr()` /
                `stride(0)` are captured once for the kernel to index into.

        Raises:
            ValueError: if a layer exposes a number of state tensors other
                than ``num_state_types`` (the kernels launch over every
                ``num_layers * num_state_types`` slot, so a missing entry would
                leave a zero base address), or if a conv state is not 3-D.
        """
        if self.is_initialized:
            return

        idx = 0
        for group_local_idx, mamba_group_id in enumerate(self.mamba_group_ids):
            layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names
            for layer_name in layer_names:
                attention = forward_context[layer_name]
                kv_caches: list[torch.Tensor] = attention.kv_cache

                # Every (layer, state type) slot is visited by the kernels,
                # so each layer must contribute exactly num_state_types states.
                if len(kv_caches) != self.num_state_types:
                    raise ValueError(
                        f"Mamba layer {layer_name!r} has {len(kv_caches)} state "
                        f"tensors, expected {self.num_state_types}"
                    )

                for state_type_idx, state in enumerate(kv_caches):
                    # Base address
                    self.state_base_addrs[idx] = state.data_ptr()

                    # Block stride (bytes between consecutive blocks)
                    # state shape: [num_blocks, ...], stride(0) = elements per block
                    if state.dim() > 1:
                        block_stride_elems = state.stride(0)
                    else:
                        block_stride_elems = state.numel()
                    self.state_block_strides[idx] = (
                        block_stride_elems * state.element_size()
                    )

                    # Element size
                    self.state_elem_sizes[idx] = state.element_size()

                    copy_func = mamba_state_copy_funcs[state_type_idx]
                    assert (
                        copy_func is get_conv_copy_spec
                        or copy_func is get_temporal_copy_spec
                    ), f"unexpected copy func: {copy_func}"
                    if copy_func is get_conv_copy_spec:
                        if state.dim() != 3:
                            raise ValueError(
                                "Expected 3D conv state cache, got "
                                f"shape {tuple(state.shape)}"
                            )
                        if is_conv_state_dim_first():
                            # DS layout: state_len is the slide axis.
                            self.state_conv_widths[idx] = state.size(2)
                            self.state_inner_sizes[idx] = 1
                            self.state_dim_row_count[idx] = state.size(1)
                            self.state_dim_row_stride[idx] = (
                                state.stride(1) * state.element_size()
                            )
                        else:
                            # SD layout: dim is contiguous.
                            self.state_conv_widths[idx] = state.size(1)
                            self.state_inner_sizes[idx] = state.stride(1)
                    else:
                        # Temporal state: inner_size = natural elements per
                        # block (prod of inner dims).  The kernel uses this
                        # to compute copy_size = inner_size * elem_size,
                        # which gives the correct byte count even when the
                        # state tensor is as_strided with padded page strides
                        # (state_block_stride would be the page size, too big).
                        self.state_conv_widths[idx] = 0
                        self.state_inner_sizes[idx] = (
                            state[0].numel() if state.dim() > 1 else 1
                        )

                    self.state_group_indices[idx] = group_local_idx
                    idx += 1

        # Cache per-group block-table base addresses and per-request stride.
        # `block_tables[i]` is the persistent 2D int32 block-table tensor for
        # `mamba_group_ids[i]`; `data_ptr()` / `stride(0)` are stable for the
        # engine's lifetime, so we capture them once here.
        assert len(block_tables) == self.num_groups, (
            f"expected {self.num_groups} block tables, got {len(block_tables)}"
        )
        strides = {bt.stride(0) for bt in block_tables}
        assert len(strides) == 1, (
            f"all mamba block tables must share stride(0), got {strides}"
        )
        self.block_table_stride_req = int(next(iter(strides)))
        for i, bt in enumerate(block_tables):
            self.block_table_ptrs[i] = bt.data_ptr()

        self.is_initialized = True

    def run_fused_postprocess(
        self,
        num_reqs: int,
        num_accepted_tokens_gpu: torch.Tensor,
        mamba_state_idx_gpu: torch.Tensor,
        num_scheduled_tokens_gpu: torch.Tensor,
        num_computed_tokens_gpu: torch.Tensor,
        num_draft_tokens_gpu: torch.Tensor,
    ) -> None:
        """
        Launch the fused V1 postprocess kernel over every request and state.

        Both the copy decisions and the mamba state copies happen on the GPU,
        so no CPU-GPU sync is needed. Returns immediately when ``num_reqs`` is
        0 or the layout metadata has not been initialized.

        Args:
            num_reqs: Number of active requests.
            num_accepted_tokens_gpu: [num_reqs] accepted token counts.
            mamba_state_idx_gpu: [num_reqs] source block indices.
            num_scheduled_tokens_gpu: [num_reqs] scheduled token counts.
            num_computed_tokens_gpu: [num_reqs] computed token counts.
            num_draft_tokens_gpu: [num_reqs] draft token counts.

        The kernel receives ``self.num_accepted_tokens_out`` as its output
        buffer, seeded here with the incoming accepted counts.
        """
        if num_reqs == 0 or not self.is_initialized:
            return

        active = slice(None, num_reqs)
        # Seed the output with the incoming counts; an entry stays unchanged
        # unless src == dst.
        self.num_accepted_tokens_out[active].copy_(num_accepted_tokens_gpu[active])

        launch_grid = (num_reqs, self.num_layers * self.num_state_types)
        postprocess_mamba_fused_kernel[launch_grid](
            num_accepted_tokens_gpu,
            mamba_state_idx_gpu,
            num_scheduled_tokens_gpu,
            num_computed_tokens_gpu,
            num_draft_tokens_gpu,
            self.block_table_ptrs,
            self.block_table_stride_req,
            self.state_base_addrs,
            self.state_block_strides,
            self.state_elem_sizes,
            self.state_inner_sizes,
            self.state_conv_widths,
            self.state_group_indices,
            self.state_dim_row_count,
            self.state_dim_row_stride,
            self.num_accepted_tokens_out,
            None,  # idx_mapping: V1 decision arrays are already in req order
            num_reqs,
            block_size=self.block_size,
            COPY_BLOCK_SIZE=1024,
            CONV_STATE_DIM_FIRST=is_conv_state_dim_first(),
        )

    def run_fused_precopy(
        self,
        num_reqs: int,
        state_idx_gpu: torch.Tensor,
        src_col_gpu: torch.Tensor,
        token_bias_gpu: torch.Tensor,
        idx_mapping: torch.Tensor,
    ) -> None:
        """Pre-copy each request's previous running block into its new window
        block before the forward pass (V2 align boundary migration).

        Args:
            num_reqs: Number of active requests (batch order).
            state_idx_gpu: [max_reqs] post-advance dst block column per req slot.
            src_col_gpu: [max_reqs] pre-advance src block column (-1 = fresh).
            token_bias_gpu: [max_reqs] accepted-token bias (num_accepted - 1).
            idx_mapping: [num_reqs] batch_idx -> req_state_idx (-1 to skip).
        """
        if num_reqs == 0 or not self.is_initialized:
            return
        total_states = self.num_layers * self.num_state_types
        grid = (num_reqs, total_states)
        precopy_mamba_align_fused_kernel[grid](
            state_idx_gpu,
            src_col_gpu,
            token_bias_gpu,
            self.block_table_ptrs,
            self.block_table_stride_req,
            self.state_base_addrs,
            self.state_block_strides,
            self.state_elem_sizes,
            self.state_inner_sizes,
            self.state_conv_widths,
            self.state_group_indices,
            self.state_dim_row_count,
            self.state_dim_row_stride,
            idx_mapping,
            num_reqs,
            COPY_BLOCK_SIZE=1024,
            CONV_STATE_DIM_FIRST=is_conv_state_dim_first(),
        )

    def run_fused_postprocess_align(
        self,
        num_reqs: int,
        num_accepted_tokens_gpu: torch.Tensor,
        state_idx_gpu: torch.Tensor,
        new_num_computed_tokens_gpu: torch.Tensor,
        idx_mapping: torch.Tensor,
    ) -> None:
        """V2 align postprocess: save the running state to the block-aligned
        position after spec-decode acceptance leaves the sequence non-aligned.

        ``num_accepted_tokens_gpu`` is updated in place (reset to 1 when the
        accepted position stays in the running block); ``new_num_computed_tokens``
        already holds the post-step computed count (PRECOMPUTED_NEW_COMPUTED).
        ``idx_mapping`` maps batch row -> req-state slot (HAS_IDX_MAPPING).
        """
        if num_reqs == 0 or not self.is_initialized:
            return
        total_states = self.num_layers * self.num_state_types
        grid = (num_reqs, total_states)
        postprocess_mamba_fused_kernel[grid](
            num_accepted_tokens_gpu,
            state_idx_gpu,
            None,  # num_scheduled: unused under PRECOMPUTED_NEW_COMPUTED
            new_num_computed_tokens_gpu,
            None,  # num_draft: unused under PRECOMPUTED_NEW_COMPUTED
            self.block_table_ptrs,
            self.block_table_stride_req,
            self.state_base_addrs,
            self.state_block_strides,
            self.state_elem_sizes,
            self.state_inner_sizes,
            self.state_conv_widths,
            self.state_group_indices,
            self.state_dim_row_count,
            self.state_dim_row_stride,
            None,  # num_accepted_out: V2 updates num_accepted in place
            idx_mapping,
            num_reqs,
            block_size=self.block_size,
            COPY_BLOCK_SIZE=1024,
            CONV_STATE_DIM_FIRST=is_conv_state_dim_first(),
            HAS_IDX_MAPPING=True,
            PRECOMPUTED_NEW_COMPUTED=True,
        )

create(max_num_reqs, kv_cache_config, num_state_types, device, make_buffer) classmethod

Create context with allocated buffers (metadata populated later).

Source code in vllm/v1/worker/mamba_utils.py
@classmethod
def create(
    cls,
    max_num_reqs: int,
    kv_cache_config: KVCacheConfig,
    num_state_types: int,
    device: torch.device,
    make_buffer: Callable[..., CpuGpuBuffer],
) -> "MambaSpecDecodeGPUContext":
    """Allocate zeroed metadata tensors and per-request staging buffers.

    The per-state metadata is filled in later by
    ``initialize_from_forward_context``; until then ``is_initialized``
    stays ``False`` and the ``run_*`` methods return without launching.
    """
    group_ids, spec = get_mamba_groups(kv_cache_config)
    groups = kv_cache_config.kv_cache_groups

    # One metadata slot per (layer, state type) across all mamba groups.
    layer_count = sum(len(groups[gid].layer_names) for gid in group_ids)
    state_count = layer_count * num_state_types

    def per_state(dtype: torch.dtype) -> torch.Tensor:
        return torch.zeros(state_count, dtype=dtype, device=device)

    def per_request():
        return make_buffer(max_num_reqs, dtype=torch.int32)

    return cls(
        state_base_addrs=per_state(torch.int64),
        state_block_strides=per_state(torch.int64),
        state_elem_sizes=per_state(torch.int32),
        state_inner_sizes=per_state(torch.int64),
        state_conv_widths=per_state(torch.int32),
        state_group_indices=per_state(torch.int32),
        state_dim_row_count=per_state(torch.int32),
        state_dim_row_stride=per_state(torch.int64),
        block_size=spec.block_size,
        num_layers=layer_count,
        num_state_types=num_state_types,
        mamba_group_ids=group_ids,
        num_groups=len(group_ids),
        num_accepted_tokens_out=torch.zeros(
            max_num_reqs, dtype=torch.int32, device=device
        ),
        block_table_ptrs=torch.zeros(
            len(group_ids), dtype=torch.int64, device=device
        ),
        mamba_state_idx_buf=per_request(),
        num_scheduled_tokens_buf=per_request(),
        num_computed_tokens_buf=per_request(),
        num_draft_tokens_buf=per_request(),
        is_initialized=False,
    )

initialize_from_forward_context(kv_cache_config, forward_context, mamba_state_copy_funcs, block_tables)

Extract and cache memory layout metadata from Mamba state tensors.

This method populates the pre-allocated metadata tensors with information needed by postprocess_mamba_fused_kernel to perform state copies entirely on the GPU without CPU-GPU synchronization.

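For each Mamba layer and state type, the following metadata is extracted:

- state_base_addrs: GPU memory address (data_ptr) of the state tensor
- state_block_strides: bytes between consecutive blocks (stride(0) * element_size, or numel * element_size for 1-D tensors)
- state_elem_sizes: element size in bytes (e.g., 2 for float16)
- state_inner_sizes: for SD-layout conv states, elements per conv position (stride(1)); 1 for dim-first (DS) conv states; for temporal states, the natural number of elements per block (state[0].numel(), or 1 for 1-D tensors), which the kernel uses to compute the copy size
- state_conv_widths: conv width for conv states, 0 for temporal states
- state_dim_row_count / state_dim_row_stride: per-block row count and byte stride between rows for DS-layout conv states (0 otherwise)
- state_group_indices: local index of the mamba group that owns the state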

The conv vs temporal state type is determined by the identity of the copy function: get_conv_copy_spec marks a conv state and get_temporal_copy_spec a temporal state.

This method is idempotent - it only executes once (guarded by is_initialized flag) since the metadata is static after model loading.

Parameters:

  • kv_cache_config

    (KVCacheConfig) –

    Configuration containing KV cache group info and layer name mappings.

  • forward_context

    (dict[str, Any]) –

    Dictionary mapping layer names to attention objects, populated after the model is loaded. Each attention object must have a kv_cache attribute containing the list of state tensors.

  • mamba_state_copy_funcs

    (tuple[MambaStateCopyFunc, ...]) –

    Tuple of copy functions (one per state type) used to determine whether each state is a conv or temporal state.

  • block_tables

    (list[Tensor]) –

    per-mamba-group persistent block-table tensors, in the same order as mamba_group_ids. Their data_ptr() / stride(0) are captured once for the kernel to index into.

Source code in vllm/v1/worker/mamba_utils.py
def initialize_from_forward_context(
    self,
    kv_cache_config: KVCacheConfig,
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    block_tables: list[torch.Tensor],
) -> None:
    """
    Extract and cache memory layout metadata from Mamba state tensors.

    This method populates the pre-allocated metadata tensors with the
    information the fused copy kernels (``postprocess_mamba_fused_kernel``
    and ``precopy_mamba_align_fused_kernel``) need to perform state copies
    entirely on the GPU without CPU-GPU synchronization.

    Slots are filled in group -> layer -> state-type order. For each slot:
    - state_base_addrs: ``data_ptr()`` of the state tensor
    - state_block_strides: bytes between consecutive blocks
      (``stride(0) * element_size()``; ``numel() * element_size()`` for
      1-D tensors)
    - state_elem_sizes: element size in bytes (e.g., 2 for float16)
    - state_conv_widths / state_inner_sizes:
      * conv state, dim-first (DS) layout: width = ``size(2)``,
        inner size = 1; additionally state_dim_row_count = ``size(1)``
        and state_dim_row_stride = ``stride(1) * element_size()``
      * conv state, SD layout: width = ``size(1)``, inner size =
        ``stride(1)`` (elements per conv position)
      * temporal state: width = 0, inner size = natural elements per
        block (``state[0].numel()``, or 1 for 1-D tensors), which keeps
        the copy size correct even when the block stride is padded
    - state_group_indices: local index of the owning mamba group

    Conv vs temporal is decided by the identity of the copy function:
    ``get_conv_copy_spec`` marks a conv state, ``get_temporal_copy_spec``
    a temporal state.

    This method runs at most once (guarded by the ``is_initialized``
    flag) since the metadata is static after model loading.

    Args:
        kv_cache_config: Configuration containing KV cache group info and
            layer name mappings.
        forward_context: Dictionary mapping layer names to attention objects,
            populated after the model is loaded. Each attention object must
            have a `kv_cache` attribute containing the list of state tensors.
        mamba_state_copy_funcs: Tuple of copy functions (one per state type)
            used to determine whether each state is a conv or temporal state.
        block_tables: per-mamba-group persistent block-table tensors, in
            the same order as `mamba_group_ids`. Their `data_ptr()` /
            `stride(0)` are captured once for the kernel to index into.

    Raises:
        ValueError: if a layer exposes a number of state tensors other
            than ``num_state_types`` (the kernels launch over every
            ``num_layers * num_state_types`` slot, so a missing entry would
            leave a zero base address), or if a conv state is not 3-D.
    """
    if self.is_initialized:
        return

    idx = 0
    for group_local_idx, mamba_group_id in enumerate(self.mamba_group_ids):
        layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names
        for layer_name in layer_names:
            attention = forward_context[layer_name]
            kv_caches: list[torch.Tensor] = attention.kv_cache

            # Every (layer, state type) slot is visited by the kernels,
            # so each layer must contribute exactly num_state_types states.
            if len(kv_caches) != self.num_state_types:
                raise ValueError(
                    f"Mamba layer {layer_name!r} has {len(kv_caches)} state "
                    f"tensors, expected {self.num_state_types}"
                )

            for state_type_idx, state in enumerate(kv_caches):
                # Base address
                self.state_base_addrs[idx] = state.data_ptr()

                # Block stride (bytes between consecutive blocks)
                # state shape: [num_blocks, ...], stride(0) = elements per block
                if state.dim() > 1:
                    block_stride_elems = state.stride(0)
                else:
                    block_stride_elems = state.numel()
                self.state_block_strides[idx] = (
                    block_stride_elems * state.element_size()
                )

                # Element size
                self.state_elem_sizes[idx] = state.element_size()

                copy_func = mamba_state_copy_funcs[state_type_idx]
                assert (
                    copy_func is get_conv_copy_spec
                    or copy_func is get_temporal_copy_spec
                ), f"unexpected copy func: {copy_func}"
                if copy_func is get_conv_copy_spec:
                    if state.dim() != 3:
                        raise ValueError(
                            "Expected 3D conv state cache, got "
                            f"shape {tuple(state.shape)}"
                        )
                    if is_conv_state_dim_first():
                        # DS layout: state_len is the slide axis.
                        self.state_conv_widths[idx] = state.size(2)
                        self.state_inner_sizes[idx] = 1
                        self.state_dim_row_count[idx] = state.size(1)
                        self.state_dim_row_stride[idx] = (
                            state.stride(1) * state.element_size()
                        )
                    else:
                        # SD layout: dim is contiguous.
                        self.state_conv_widths[idx] = state.size(1)
                        self.state_inner_sizes[idx] = state.stride(1)
                else:
                    # Temporal state: inner_size = natural elements per
                    # block (prod of inner dims).  The kernel uses this
                    # to compute copy_size = inner_size * elem_size,
                    # which gives the correct byte count even when the
                    # state tensor is as_strided with padded page strides
                    # (state_block_stride would be the page size, too big).
                    self.state_conv_widths[idx] = 0
                    self.state_inner_sizes[idx] = (
                        state[0].numel() if state.dim() > 1 else 1
                    )

                self.state_group_indices[idx] = group_local_idx
                idx += 1

    # Cache per-group block-table base addresses and per-request stride.
    # `block_tables[i]` is the persistent 2D int32 block-table tensor for
    # `mamba_group_ids[i]`; `data_ptr()` / `stride(0)` are stable for the
    # engine's lifetime, so we capture them once here.
    assert len(block_tables) == self.num_groups, (
        f"expected {self.num_groups} block tables, got {len(block_tables)}"
    )
    strides = {bt.stride(0) for bt in block_tables}
    assert len(strides) == 1, (
        f"all mamba block tables must share stride(0), got {strides}"
    )
    self.block_table_stride_req = int(next(iter(strides)))
    for i, bt in enumerate(block_tables):
        self.block_table_ptrs[i] = bt.data_ptr()

    self.is_initialized = True

run_fused_postprocess(num_reqs, num_accepted_tokens_gpu, mamba_state_idx_gpu, num_scheduled_tokens_gpu, num_computed_tokens_gpu, num_draft_tokens_gpu)

Run the fused postprocess_mamba kernel on GPU.

This computes decisions and performs mamba state copies entirely on GPU, eliminating the CPU-GPU sync that was previously needed.

Parameters:

  • num_reqs

    (int) –

    Number of active requests

  • num_accepted_tokens_gpu

    (Tensor) –

    [num_reqs] accepted token counts

  • mamba_state_idx_gpu

    (Tensor) –

    [num_reqs] source block indices

  • num_scheduled_tokens_gpu

    (Tensor) –

    [num_reqs] scheduled token counts

  • num_computed_tokens_gpu

    (Tensor) –

    [num_reqs] computed token counts

  • num_draft_tokens_gpu

    (Tensor) –

    [num_reqs] draft token counts

Source code in vllm/v1/worker/mamba_utils.py
def run_fused_postprocess(
    self,
    num_reqs: int,
    num_accepted_tokens_gpu: torch.Tensor,
    mamba_state_idx_gpu: torch.Tensor,
    num_scheduled_tokens_gpu: torch.Tensor,
    num_computed_tokens_gpu: torch.Tensor,
    num_draft_tokens_gpu: torch.Tensor,
) -> None:
    """
    Launch the fused V1 postprocess kernel for the current batch.

    The kernel decides, per request, whether the running mamba state has to
    be saved at a block-aligned position and performs the byte copies itself,
    so the whole step runs on the GPU with no CPU-GPU synchronization.

    Args:
        num_reqs: Number of active requests (0 makes this a no-op).
        num_accepted_tokens_gpu: [num_reqs] accepted token counts
        mamba_state_idx_gpu: [num_reqs] source block indices
        num_scheduled_tokens_gpu: [num_reqs] scheduled token counts
        num_computed_tokens_gpu: [num_reqs] computed token counts
        num_draft_tokens_gpu: [num_reqs] draft token counts

    The updated accepted-token counts are written to
    ``self.num_accepted_tokens_out[:num_reqs]``. Nothing is launched until
    the context has been initialized.
    """
    nothing_to_do = num_reqs == 0 or not self.is_initialized
    if nothing_to_do:
        return

    # Seed the output with the incoming counts. The kernel only overwrites the
    # entries whose source and destination block coincide.
    out_view = self.num_accepted_tokens_out[:num_reqs]
    out_view.copy_(num_accepted_tokens_gpu[:num_reqs])

    # One program per (request, flattened layer/state-type slot).
    launch_grid = (num_reqs, self.num_layers * self.num_state_types)
    state_layout = (
        self.block_table_ptrs,
        self.block_table_stride_req,
        self.state_base_addrs,
        self.state_block_strides,
        self.state_elem_sizes,
        self.state_inner_sizes,
        self.state_conv_widths,
        self.state_group_indices,
        self.state_dim_row_count,
        self.state_dim_row_stride,
    )
    postprocess_mamba_fused_kernel[launch_grid](
        num_accepted_tokens_gpu,
        mamba_state_idx_gpu,
        num_scheduled_tokens_gpu,
        num_computed_tokens_gpu,
        num_draft_tokens_gpu,
        *state_layout,
        self.num_accepted_tokens_out,
        None,  # no idx_mapping: V1 decision arrays are already in req order
        num_reqs,
        block_size=self.block_size,
        COPY_BLOCK_SIZE=1024,
        CONV_STATE_DIM_FIRST=is_conv_state_dim_first(),
    )

run_fused_postprocess_align(num_reqs, num_accepted_tokens_gpu, state_idx_gpu, new_num_computed_tokens_gpu, idx_mapping)

V2 align postprocess: save the running state to the block-aligned position after spec-decode acceptance leaves the sequence non-aligned.

num_accepted_tokens_gpu is updated in place (reset to 1 when the accepted position stays in the running block); new_num_computed_tokens_gpu already holds the post-step computed count (PRECOMPUTED_NEW_COMPUTED). idx_mapping maps each batch row to its req-state slot (HAS_IDX_MAPPING).

Source code in vllm/v1/worker/mamba_utils.py
def run_fused_postprocess_align(
    self,
    num_reqs: int,
    num_accepted_tokens_gpu: torch.Tensor,
    state_idx_gpu: torch.Tensor,
    new_num_computed_tokens_gpu: torch.Tensor,
    idx_mapping: torch.Tensor,
) -> None:
    """Persist the running state at its block-aligned slot (V2 align mode).

    Speculative-decode acceptance can leave the sequence ending mid-block.
    This launches the fused postprocess kernel in its V2 configuration, with
    both ``HAS_IDX_MAPPING`` and ``PRECOMPUTED_NEW_COMPUTED`` enabled:

    - ``new_num_computed_tokens_gpu`` must already contain the post-step
      computed-token count, so no scheduled/draft counts are passed.
    - ``num_accepted_tokens_gpu`` is modified in place (set to 1 when a copy
      is needed and its destination block is the running block); there is
      no separate output buffer.
    - ``idx_mapping`` maps each batch row to its req-state slot.

    Does nothing when ``num_reqs`` is 0 or the context is not initialized.
    """
    if num_reqs == 0 or not self.is_initialized:
        return
    state_count = self.num_layers * self.num_state_types
    postprocess_mamba_fused_kernel[(num_reqs, state_count)](
        num_accepted_tokens_gpu,
        state_idx_gpu,
        None,  # scheduled-token counts: ignored with PRECOMPUTED_NEW_COMPUTED
        new_num_computed_tokens_gpu,
        None,  # draft-token counts: ignored with PRECOMPUTED_NEW_COMPUTED
        self.block_table_ptrs,
        self.block_table_stride_req,
        self.state_base_addrs,
        self.state_block_strides,
        self.state_elem_sizes,
        self.state_inner_sizes,
        self.state_conv_widths,
        self.state_group_indices,
        self.state_dim_row_count,
        self.state_dim_row_stride,
        None,  # accepted-count output buffer: V2 writes num_accepted in place
        idx_mapping,
        num_reqs,
        block_size=self.block_size,
        COPY_BLOCK_SIZE=1024,
        CONV_STATE_DIM_FIRST=is_conv_state_dim_first(),
        HAS_IDX_MAPPING=True,
        PRECOMPUTED_NEW_COMPUTED=True,
    )

run_fused_precopy(num_reqs, state_idx_gpu, src_col_gpu, token_bias_gpu, idx_mapping)

Pre-copy each request's previous running block into its new window block before the forward pass (V2 align boundary migration).

Parameters:

  • num_reqs

    (int) –

    Number of active requests (batch order).

  • state_idx_gpu

    (Tensor) –

    [max_reqs] post-advance dst block column per req slot.

  • src_col_gpu

    (Tensor) –

    [max_reqs] pre-advance src block column (-1 = fresh).

  • token_bias_gpu

    (Tensor) –

    [max_reqs] accepted-token bias (num_accepted - 1).

  • idx_mapping

    (Tensor) –

    [num_reqs] batch_idx -> req_state_idx (-1 to skip).

Source code in vllm/v1/worker/mamba_utils.py
def run_fused_precopy(
    self,
    num_reqs: int,
    state_idx_gpu: torch.Tensor,
    src_col_gpu: torch.Tensor,
    token_bias_gpu: torch.Tensor,
    idx_mapping: torch.Tensor,
) -> None:
    """Migrate each request's running state into its new window block.

    Runs before the forward pass (V2 align boundary migration): for every
    batch row, the state held in the previous running block column is copied
    into the new one, so the model reads its initial state from the block
    it is about to write.

    Args:
        num_reqs: Number of active requests (batch order).
        state_idx_gpu: [max_reqs] post-advance dst block column per req slot.
        src_col_gpu: [max_reqs] pre-advance src block column (-1 = fresh).
        token_bias_gpu: [max_reqs] accepted-token bias (num_accepted - 1).
        idx_mapping: [num_reqs] batch_idx -> req_state_idx (-1 to skip).

    Does nothing when ``num_reqs`` is 0 or the context is not initialized.
    """
    if not self.is_initialized if num_reqs != 0 else True:
        return
    layout_args = (
        self.block_table_ptrs,
        self.block_table_stride_req,
        self.state_base_addrs,
        self.state_block_strides,
        self.state_elem_sizes,
        self.state_inner_sizes,
        self.state_conv_widths,
        self.state_group_indices,
        self.state_dim_row_count,
        self.state_dim_row_stride,
    )
    programs = (num_reqs, self.num_layers * self.num_state_types)
    precopy_mamba_align_fused_kernel[programs](
        state_idx_gpu,
        src_col_gpu,
        token_bias_gpu,
        *layout_args,
        idx_mapping,
        num_reqs,
        COPY_BLOCK_SIZE=1024,
        CONV_STATE_DIM_FIRST=is_conv_state_dim_first(),
    )

_copy_mamba_state_block(state_idx, bt_row_idx, src_col, dst_col, token_bias, block_table_ptrs_ptr, block_table_stride_req, state_base_addrs_ptr, state_block_strides_ptr, state_elem_sizes_ptr, state_inner_sizes_ptr, state_conv_widths_ptr, state_group_indices_ptr, state_dim_row_count_ptr, state_dim_row_stride_ptr, COPY_BLOCK_SIZE, CONV_STATE_DIM_FIRST)

Copy one (layer, state-type) mamba state block between block columns.

Shared copy body of postprocess_mamba_fused_kernel and precopy_mamba_align_fused_kernel, mirroring the V1 copy specs (get_conv_copy_spec / get_temporal_copy_spec):

- conv state (conv_width > 0): shift the window by token_bias tokens, state[bt[src_col], token_bias:] -> state[bt[dst_col], :conv_width - token_bias]
- temporal state: token_bias selects the accepted speculative column, state[bt[src_col + token_bias]] -> state[bt[dst_col]]

The caller owns the decision logic (which columns, whether to copy); this device function only performs the byte copy for the given metadata slot.

Source code in vllm/v1/worker/mamba_utils.py
@triton.jit
def _copy_mamba_state_block(
    state_idx,
    bt_row_idx,
    src_col,
    dst_col,
    token_bias,
    block_table_ptrs_ptr,
    block_table_stride_req,
    state_base_addrs_ptr,
    state_block_strides_ptr,
    state_elem_sizes_ptr,
    state_inner_sizes_ptr,
    state_conv_widths_ptr,
    state_group_indices_ptr,
    # DS conv row metadata. Zero keeps the single-region copy path.
    state_dim_row_count_ptr,
    state_dim_row_stride_ptr,
    COPY_BLOCK_SIZE: tl.constexpr,
    CONV_STATE_DIM_FIRST: tl.constexpr,
):
    """Copy one (layer, state-type) mamba state block between block columns.

    Shared copy body of ``postprocess_mamba_fused_kernel`` and
    ``precopy_mamba_align_fused_kernel``, mirroring the V1 copy specs
    (``get_conv_copy_spec`` / ``get_temporal_copy_spec``):
    - conv state (conv_width > 0): shift the window by ``token_bias`` tokens,
      ``state[bt[src_col], token_bias:] ->
      state[bt[dst_col], :conv_width - token_bias]``
    - temporal state: ``token_bias`` selects the accepted speculative column,
      ``state[bt[src_col + token_bias]] -> state[bt[dst_col]]``

    The caller owns the decision logic (which columns, whether to copy); this
    device function only performs the byte copy for the given metadata slot.

    Args:
        state_idx: flattened (layer, state-type) slot; every ``state_*_ptr``
            metadata array is indexed with it.
        bt_row_idx: row of the block table to read block ids from.
        src_col / dst_col: block-table columns of the source and destination
            blocks.
        token_bias: token offset applied on the source side (see above).
        block_table_ptrs_ptr: per-group block-table base addresses; each is
            reinterpreted as an int32 pointer below.
        block_table_stride_req: element stride between block-table rows.
        COPY_BLOCK_SIZE: bytes moved per load/store step of the copy loops.
        CONV_STATE_DIM_FIRST: selects the per-row ("DS") conv copy layout.

    All copies are done byte-wise through uint8 pointers, so the element
    dtype only enters via ``state_elem_sizes``.
    """
    # Per-slot layout metadata (addresses/strides are in bytes, sizes in
    # elements, as used by the address arithmetic below).
    state_base_addr = tl.load(state_base_addrs_ptr + state_idx)
    state_block_stride = tl.load(state_block_strides_ptr + state_idx)
    state_elem_size = tl.load(state_elem_sizes_ptr + state_idx)
    state_inner_size = tl.load(state_inner_sizes_ptr + state_idx)
    # A nonzero conv width marks a conv state; zero marks a temporal state.
    conv_width = tl.load(state_conv_widths_ptr + state_idx)

    # Load the group index for this state, then index into the correct
    # group's block table. Each mamba group has independently allocated
    # physical blocks. Reinterpret as int32* since block ids are int32.
    group_idx = tl.load(state_group_indices_ptr + state_idx).to(tl.int64)
    group_base_addr = tl.load(block_table_ptrs_ptr + group_idx)
    block_table_typed = group_base_addr.to(tl.pointer_type(tl.int32))
    block_table_base = block_table_typed + bt_row_idx * block_table_stride_req

    # Widen block ids to int64 before they reach `block_id * state_block_stride`
    # below: state_block_stride can exceed 2**31 bytes for large mamba caches,
    # and Triton would otherwise do the multiply in int32 and wrap.
    dest_block_id = tl.load(block_table_base + dst_col).to(tl.int64)
    dst_addr = state_base_addr + dest_block_id * state_block_stride

    is_conv_state = conv_width > 0

    # NOTE(review): when src_col == dst_col and token_bias > 0 (reachable from
    # postprocess_mamba_fused_kernel, which only skips src == dst with zero
    # bias), both conv paths below copy inside a single physical block with
    # overlapping source/destination ranges. The chunked load/store loops
    # have no barrier between iterations; confirm this is race-free for the
    # Triton launch configuration in use.
    if CONV_STATE_DIM_FIRST and is_conv_state:
        # DS conv layout: state_len is the slide axis; copy per dim row.
        # Each of the dim_rows rows is shifted left by token_bias elements:
        # (conv_width - token_bias) elements per row are copied.
        # NOTE(review): dim_rows == 0 makes this branch copy nothing; the
        # "Zero keeps the single-region copy path" comment on the parameter
        # is not enforced here, so DS conv slots presumably always carry a
        # nonzero row count — verify.
        src_block_id = tl.load(block_table_base + src_col).to(tl.int64)
        dim_rows = tl.load(state_dim_row_count_ptr + state_idx)
        row_stride = tl.load(state_dim_row_stride_ptr + state_idx)
        per_row_bytes = (conv_width - token_bias).to(tl.int64) * state_elem_size
        bias_bytes = token_bias.to(tl.int64) * state_elem_size
        src_block_addr = state_base_addr + src_block_id * state_block_stride
        offsets = tl.arange(0, COPY_BLOCK_SIZE)
        for d in range(0, dim_rows):
            row_src = src_block_addr + d * row_stride + bias_bytes
            row_dst = dst_addr + d * row_stride
            # Masked byte copy of one row in COPY_BLOCK_SIZE-byte chunks.
            for i in range(0, per_row_bytes, COPY_BLOCK_SIZE):
                mask = (i + offsets) < per_row_bytes
                curr_src = (row_src + i + offsets).to(tl.pointer_type(tl.uint8))
                curr_dst = (row_dst + i + offsets).to(tl.pointer_type(tl.uint8))
                data = tl.load(curr_src, mask=mask)
                tl.store(curr_dst, data, mask=mask)
        return

    # The remaining two paths each describe a single contiguous byte range
    # (src_addr, copy_size), copied by the shared loop at the end.
    if is_conv_state:
        # SD conv: copy
        #   state[bt[src_col], token_bias:] ->
        #   state[bt[dst_col], :conv_width - token_bias]
        src_block_id = tl.load(block_table_base + src_col).to(tl.int64)
        src_offset = token_bias.to(tl.int64) * state_inner_size * state_elem_size
        src_addr = state_base_addr + src_block_id * state_block_stride + src_offset
        num_elems_to_copy = (conv_width - token_bias).to(tl.int64) * state_inner_size
        copy_size = num_elems_to_copy * state_elem_size
    else:
        # Temporal state: copy state[bt[src_col + token_bias]] -> state[bt[dst_col]]
        actual_src_block_id = tl.load(block_table_base + src_col + token_bias).to(
            tl.int64
        )
        src_addr = state_base_addr + actual_src_block_id * state_block_stride
        # Use natural block data size (inner_size * elem_size), NOT
        # state_block_stride which is the page stride and can exceed the
        # actual data when the state tensor uses as_strided page padding.
        copy_size = state_inner_size * state_elem_size

    # Masked byte copy of [src_addr, src_addr + copy_size) to dst_addr in
    # COPY_BLOCK_SIZE-byte chunks; the mask trims the final partial chunk.
    offsets = tl.arange(0, COPY_BLOCK_SIZE)
    for i in range(0, copy_size, COPY_BLOCK_SIZE):
        mask = (i + offsets) < copy_size
        curr_src = (src_addr + i + offsets).to(tl.pointer_type(tl.uint8))
        curr_dst = (dst_addr + i + offsets).to(tl.pointer_type(tl.uint8))
        data = tl.load(curr_src, mask=mask)
        tl.store(curr_dst, data, mask=mask)

cleanup_mamba_state_idx(scheduler_output, mamba_state_idx)

Pop stale mamba_state_idx entries for finished/preempted/resumed reqs.

Force-preempted requests (e.g., during reset_prefix_cache / KV cache flush) appear in resumed_req_ids without a corresponding entry in preempted_req_ids, leaving stale entries that can point to block indices beyond the new (smaller) block allocation.

Source code in vllm/v1/worker/mamba_utils.py
def cleanup_mamba_state_idx(
    scheduler_output: SchedulerOutput,
    mamba_state_idx: dict[str, int],
) -> None:
    """Drop `mamba_state_idx` entries that can no longer be trusted.

    An entry is removed for every request the scheduler reports this step as
    finished, preempted, or resumed. Resumed requests must be included: a
    forced preemption (e.g. reset_prefix_cache / KV cache flush) can surface
    a request in resumed_req_ids without it ever appearing in
    preempted_req_ids. Its old entry could then point at a block index past
    the request's new, smaller block allocation.
    """
    stale_id_groups = (
        scheduler_output.finished_req_ids,
        scheduler_output.preempted_req_ids or set(),
        scheduler_output.scheduled_cached_reqs.resumed_req_ids,
    )
    for id_group in stale_id_groups:
        for stale_id in id_group:
            mamba_state_idx.pop(stale_id, None)

postprocess_mamba_align_gpu(*, bufs, num_reqs, num_accepted_tokens_gpu, num_accepted_tokens_cpu_tensor, input_batch, kv_cache_config, forward_context, mamba_state_copy_funcs)

GPU-side mamba postprocess for spec decode + hybrid + align mode.

Lazily binds the fused-kernel context to the persistent block tables and forward-context state pointers on the first call, runs the fused kernel, and async-copies the per-request accepted-token counts back to the input batch's CPU tensor for the next iteration's preprocess.

Source code in vllm/v1/worker/mamba_utils.py
def postprocess_mamba_align_gpu(
    *,
    bufs: "MambaBuffers",
    num_reqs: int,
    num_accepted_tokens_gpu: torch.Tensor,
    num_accepted_tokens_cpu_tensor: torch.Tensor,
    input_batch: GPUInputBatch,
    kv_cache_config: KVCacheConfig,
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
) -> None:
    """Run the GPU mamba postprocess for spec decode + hybrid + align mode.

    The fused-kernel context in ``bufs.postprocess_align`` is bound on first
    use to the persistent per-group block tables and to the forward-context
    state tensors. The fused kernel then runs, and the resulting per-request
    accepted-token counts are copied asynchronously (``non_blocking=True``)
    into ``num_accepted_tokens_cpu_tensor`` for the next iteration's
    preprocess.
    """
    align_ctx = bufs.postprocess_align
    # The caller gates this path on spec decode + hybrid. These asserts act as
    # a tripwire should those gates and the buffer allocation ever diverge.
    assert align_ctx is not None
    assert align_ctx.mamba_state_idx_buf is not None
    assert align_ctx.num_scheduled_tokens_buf is not None
    assert align_ctx.num_computed_tokens_buf is not None
    assert align_ctx.num_draft_tokens_buf is not None

    if not align_ctx.is_initialized:
        group_block_tables = [
            input_batch.block_table[group_id].get_device_tensor(num_reqs)
            for group_id in align_ctx.mamba_group_ids
        ]
        align_ctx.initialize_from_forward_context(
            kv_cache_config,
            forward_context,
            mamba_state_copy_funcs,
            group_block_tables,
        )

    align_ctx.run_fused_postprocess(
        num_reqs=num_reqs,
        num_accepted_tokens_gpu=num_accepted_tokens_gpu,
        mamba_state_idx_gpu=align_ctx.mamba_state_idx_buf.gpu,
        num_scheduled_tokens_gpu=align_ctx.num_scheduled_tokens_buf.gpu,
        num_computed_tokens_gpu=align_ctx.num_computed_tokens_buf.gpu,
        num_draft_tokens_gpu=align_ctx.num_draft_tokens_buf.gpu,
    )

    # run_fused_postprocess seeds num_accepted_tokens_out from
    # num_accepted_tokens_gpu. The kernel rewrites an entry to 1 only when the
    # source and destination block coincide, so every other request keeps
    # its original count.
    cpu_dst = num_accepted_tokens_cpu_tensor[:num_reqs]
    cpu_dst.copy_(align_ctx.num_accepted_tokens_out[:num_reqs], non_blocking=True)

postprocess_mamba_all(scheduler_output, kv_cache_config, input_batch, requests, mamba_state_idx, num_spec_tokens, num_reqs)

All-mode postprocess (only meaningful with num_spec_tokens > 0): record per-request the block index of the last token scheduled this step, so the next step can anchor its in-place writes when accepted drafts leave the sequence at a non-block-aligned position.

Source code in vllm/v1/worker/mamba_utils.py
def postprocess_mamba_all(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    mamba_state_idx: dict[str, int],
    num_spec_tokens: int,
    num_reqs: int,
):
    """Record each request's last-token block index for all-mode (spec only).

    Does nothing unless ``num_spec_tokens > 0``. For every one of the first
    ``num_reqs`` requests in the batch that ran a full decode step of
    ``1 + num_spec_tokens`` tokens, store the block index of the last token
    scheduled this step. The next step uses it as the anchor for in-place
    writes when the accepted drafts leave the sequence at a non-block-aligned
    position. Every other request's anchor is dropped.
    """
    if num_spec_tokens <= 0:
        return
    _, spec = get_mamba_groups(kv_cache_config)
    block_size = spec.block_size
    decode_len = num_spec_tokens + 1
    per_req_scheduled = scheduler_output.num_scheduled_tokens
    for req_id in input_batch.req_ids[:num_reqs]:
        if per_req_scheduled.get(req_id, 0) != decode_len:
            # Not a full (1 + num_spec_tokens)-token decode step: no anchor.
            mamba_state_idx.pop(req_id, None)
            continue
        last_token_pos = requests[req_id].num_computed_tokens + decode_len - 1
        mamba_state_idx[req_id] = max(0, last_token_pos // block_size)

postprocess_mamba_fused_kernel(num_accepted_tokens_ptr, mamba_state_idx_ptr, num_scheduled_tokens_ptr, num_computed_tokens_ptr, num_draft_tokens_ptr, block_table_ptrs_ptr, block_table_stride_req, state_base_addrs_ptr, state_block_strides_ptr, state_elem_sizes_ptr, state_inner_sizes_ptr, state_conv_widths_ptr, state_group_indices_ptr, state_dim_row_count_ptr, state_dim_row_stride_ptr, num_accepted_tokens_out_ptr, idx_mapping_ptr, num_reqs, block_size, COPY_BLOCK_SIZE, CONV_STATE_DIM_FIRST, HAS_IDX_MAPPING=False, PRECOMPUTED_NEW_COMPUTED=False)

Fused GPU kernel for postprocess_mamba that computes decisions AND performs mamba state copies without any CPU-GPU synchronization.

Grid: (num_reqs, num_layers * num_state_types)

- program_id(0) = request/batch index
- program_id(1) = state_idx (flattened index into layer/state_type metadata)

Note: num_layers and num_state_types are not passed as kernel parameters because the kernel indexes directly into pre-flattened metadata arrays using program_id(1). The grid dimensions encode the total state count.

Source code in vllm/v1/worker/mamba_utils.py
@triton.jit
def postprocess_mamba_fused_kernel(
    # Decision inputs (per-request)
    num_accepted_tokens_ptr,
    mamba_state_idx_ptr,
    num_scheduled_tokens_ptr,
    num_computed_tokens_ptr,
    num_draft_tokens_ptr,
    # Per-group block table base addresses: int64[num_groups]. Each entry is
    # the data_ptr of that group's persistent [max_reqs, max_blocks] int32
    # block table.
    block_table_ptrs_ptr,
    block_table_stride_req: tl.int64,  # stride between requests (in elements)
    # Mamba state metadata (per-layer, per-state-type)
    # These are 1D arrays indexed by (layer_idx * num_state_types + state_type_idx)
    state_base_addrs_ptr,  # base address of each state tensor
    state_block_strides_ptr,  # bytes per block for each state
    state_elem_sizes_ptr,  # element size for each state
    state_inner_sizes_ptr,  # number of elements in inner dimensions
    state_conv_widths_ptr,  # conv width for conv states (0 for temporal)
    state_group_indices_ptr,  # maps state_idx to group index in block table
    # DS conv row metadata. Zero keeps the single-region copy path.
    state_dim_row_count_ptr,  # int32: per-block dim row count for DS conv
    state_dim_row_stride_ptr,  # int64: bytes between rows for DS conv
    # Output: num_accepted_tokens update (for src==dst case)
    num_accepted_tokens_out_ptr,
    # Optional: batch_idx -> req_idx mapping (V2 model runner / PP). The
    # per-request decision arrays are in req-state-slot order; the block table
    # is in batch order, so HAS_IDX_MAPPING splits the two indexings.
    idx_mapping_ptr,
    # Runtime parameter (varies per batch - NOT constexpr to avoid recompilation)
    num_reqs,
    # Compile-time constants (fixed after model initialization)
    # block_size: determined by model config, constant for all invocations
    block_size: tl.constexpr,
    # COPY_BLOCK_SIZE: fixed tuning parameter for memory copy loop
    COPY_BLOCK_SIZE: tl.constexpr,
    CONV_STATE_DIM_FIRST: tl.constexpr,
    # HAS_IDX_MAPPING: when True, program_id(0) is a batch index resolved to a
    # req-state slot via idx_mapping_ptr (V2). When False, it is the req index.
    HAS_IDX_MAPPING: tl.constexpr = False,
    # PRECOMPUTED_NEW_COMPUTED: when True, num_computed_tokens_ptr already holds
    # the post-step new_num_computed value (V2 supplies the advanced count).
    PRECOMPUTED_NEW_COMPUTED: tl.constexpr = False,
):
    """
    Fused GPU kernel for postprocess_mamba that computes decisions AND performs
    mamba state copies without any CPU-GPU synchronization.

    Grid: (num_reqs, num_layers * num_state_types)
    - program_id(0) = request/batch index
    - program_id(1) = state_idx (flattened index into layer/state_type metadata)

    Note: num_layers and num_state_types are not passed as kernel parameters
    because the kernel indexes directly into pre-flattened metadata arrays
    using program_id(1). The grid dimensions encode the total state count.

    Decision (identical for every state_idx program of the same request):
    - ``num_tokens_running_state``: tokens covered by the running state.
    - ``new_num_computed``: computed-token count after acceptance.
    - A copy happens only if rounding ``new_num_computed`` down to a multiple
      of ``block_size`` lands at or beyond ``num_tokens_running_state``. The
      destination is the block column ending at that aligned position, and
      ``accept_token_bias`` is the token distance from the running state to
      it.

    Call modes used in this file:
    - V1 (``run_fused_postprocess``): scheduled/draft counts are supplied,
      and accepted-count updates go to ``num_accepted_tokens_out_ptr``.
    - V2 (``run_fused_postprocess_align``, HAS_IDX_MAPPING and
      PRECOMPUTED_NEW_COMPUTED): scheduled/draft/out pointers are None, and
      ``num_accepted_tokens_ptr`` is updated in place.
    """
    batch_idx = tl.program_id(0)
    state_idx = tl.program_id(1)

    # Bounds check
    if batch_idx >= num_reqs:
        return

    # Resolve where the per-request decision inputs live; -1 marks a row to skip.
    if HAS_IDX_MAPPING:
        req_idx = tl.load(idx_mapping_ptr + batch_idx)
        if req_idx < 0:
            return
    else:
        req_idx = batch_idx

    # Compute decision logic (mirrors postprocess_mamba Python reference)
    num_accepted = tl.load(num_accepted_tokens_ptr + req_idx)
    # Block column currently holding the running state (copy source).
    src_block_idx = tl.load(mamba_state_idx_ptr + req_idx)

    # Both branches enforce new_num_computed == running_state + num_accepted - 1.
    if PRECOMPUTED_NEW_COMPUTED:
        new_num_computed = tl.load(num_computed_tokens_ptr + req_idx)
        num_tokens_running_state = new_num_computed - num_accepted + 1
    else:
        num_scheduled = tl.load(num_scheduled_tokens_ptr + req_idx)
        num_computed = tl.load(num_computed_tokens_ptr + req_idx)
        num_draft = tl.load(num_draft_tokens_ptr + req_idx)
        num_tokens_running_state = num_computed + num_scheduled - num_draft
        new_num_computed = num_tokens_running_state + num_accepted - 1

    # Largest multiple of block_size not exceeding new_num_computed.
    aligned_new_computed = (new_num_computed // block_size) * block_size

    # Nothing to persist unless that aligned boundary is at or past the
    # position the running state already represents.
    needs_copy = aligned_new_computed >= num_tokens_running_state

    if not needs_copy:
        return

    # Compute copy parameters
    accept_token_bias = aligned_new_computed - num_tokens_running_state
    # Block column containing the last token before the aligned boundary.
    dest_block_idx = aligned_new_computed // block_size - 1

    # Update accepted-token count before early exits (per-request, so only
    # state_idx == 0 writes). V2 updates in place; V1 writes the _out buffer.
    if src_block_idx == dest_block_idx and state_idx == 0:
        if HAS_IDX_MAPPING:
            tl.store(num_accepted_tokens_ptr + req_idx, 1)
        else:
            tl.store(num_accepted_tokens_out_ptr + req_idx, 1)

    # Skip no-op self-copy.
    if src_block_idx == dest_block_idx and accept_token_bias == 0:
        return

    # The block table is in batch order under V2, req order under V1.
    bt_row_idx = batch_idx if HAS_IDX_MAPPING else req_idx
    _copy_mamba_state_block(
        state_idx,
        bt_row_idx,
        src_block_idx,
        dest_block_idx,
        accept_token_bias,
        block_table_ptrs_ptr,
        block_table_stride_req,
        state_base_addrs_ptr,
        state_block_strides_ptr,
        state_elem_sizes_ptr,
        state_inner_sizes_ptr,
        state_conv_widths_ptr,
        state_group_indices_ptr,
        state_dim_row_count_ptr,
        state_dim_row_stride_ptr,
        COPY_BLOCK_SIZE,
        CONV_STATE_DIM_FIRST,
    )

precopy_mamba_align_fused_kernel(mamba_state_idx_ptr, src_col_ptr, token_bias_ptr, block_table_ptrs_ptr, block_table_stride_req, state_base_addrs_ptr, state_block_strides_ptr, state_elem_sizes_ptr, state_inner_sizes_ptr, state_conv_widths_ptr, state_group_indices_ptr, state_dim_row_count_ptr, state_dim_row_stride_ptr, idx_mapping_ptr, num_reqs, COPY_BLOCK_SIZE, CONV_STATE_DIM_FIRST)

Pre-copy mamba "align" state across block boundaries on the V2 runner.

Before the forward pass, copy each request's last SSM/conv state from its previous block column into the new window block column, so the kernels read the initial state from the write-side block as usual (V1 align semantics). Same per-(layer, state) copy semantics as postprocess_mamba_fused_kernel (shared _copy_mamba_state_block body, i.e. the V1 preprocess_mamba copy specs), but driven by the GPU-resident src columns so it needs no CPU-GPU sync (async-scheduling safe).

Grid: (num_reqs, num_layers * num_state_types); block tables are indexed by batch row, per-request state by req_idx via idx_mapping (V2 layout).

Source code in vllm/v1/worker/mamba_utils.py
@triton.jit
def precopy_mamba_align_fused_kernel(
    # Per-request-slot inputs (indexed by req_idx via idx_mapping), produced by
    # the V2 fused align preprocess kernel for the current step:
    mamba_state_idx_ptr,  # post-advance dst block column
    src_col_ptr,  # pre-advance src block column (-1 = fresh)
    token_bias_ptr,  # accepted-token bias = num_accepted - 1 (pre-reset)
    # Same flattened state-layout metadata as postprocess_mamba_fused_kernel
    block_table_ptrs_ptr,
    block_table_stride_req: tl.int64,
    state_base_addrs_ptr,
    state_block_strides_ptr,
    state_elem_sizes_ptr,
    state_inner_sizes_ptr,
    state_conv_widths_ptr,
    state_group_indices_ptr,
    state_dim_row_count_ptr,
    state_dim_row_stride_ptr,
    idx_mapping_ptr,  # [num_reqs] batch_idx -> req_state_idx (-1 to skip)
    num_reqs,
    COPY_BLOCK_SIZE: tl.constexpr,
    CONV_STATE_DIM_FIRST: tl.constexpr,
):
    """Pre-copy mamba "align" state across block boundaries on the V2 runner.

    Before the forward pass, copy each request's last SSM/conv state from its
    previous block column into the new window block column, so the kernels read
    the initial state from the write-side block as usual (V1 align semantics).
    Same per-(layer, state) copy semantics as ``postprocess_mamba_fused_kernel``
    (shared ``_copy_mamba_state_block`` body, i.e. the V1 ``preprocess_mamba``
    copy specs), but driven by the GPU-resident src columns so it needs no
    CPU-GPU sync (async-scheduling safe).

    Grid: (num_reqs, num_layers * num_state_types); block tables are indexed by
    batch row, per-request state by req_idx via idx_mapping (V2 layout).

    A program returns without copying when its batch row is past
    ``num_reqs``, the row maps to -1, the source column is negative (fresh
    request), or the source and destination columns are equal.
    """
    batch_idx = tl.program_id(0)
    state_idx = tl.program_id(1)
    if batch_idx >= num_reqs:
        return
    # Per-request inputs are stored by req-state slot; the block table itself
    # is indexed by batch row (batch_idx is passed as bt_row_idx below).
    req_idx = tl.load(idx_mapping_ptr + batch_idx)
    if req_idx < 0:
        return

    src_col = tl.load(src_col_ptr + req_idx)
    dst_col = tl.load(mamba_state_idx_ptr + req_idx)
    # Fresh state, or still writing the same block: kernels locate the initial
    # state in-block via num_accepted (preserved when no boundary is crossed),
    # so there is nothing to copy.
    if src_col < 0 or src_col == dst_col:
        return

    token_bias = tl.load(token_bias_ptr + req_idx)
    _copy_mamba_state_block(
        state_idx,
        batch_idx,
        src_col,
        dst_col,
        token_bias,
        block_table_ptrs_ptr,
        block_table_stride_req,
        state_base_addrs_ptr,
        state_block_strides_ptr,
        state_elem_sizes_ptr,
        state_inner_sizes_ptr,
        state_conv_widths_ptr,
        state_group_indices_ptr,
        state_dim_row_count_ptr,
        state_dim_row_stride_ptr,
        COPY_BLOCK_SIZE,
        CONV_STATE_DIM_FIRST,
    )

preprocess_mamba(scheduler_output, kv_cache_config, cache_config, mamba_state_idx, input_batch, requests, forward_context, mamba_state_copy_funcs, copy_bufs)

Copy the mamba state of the previous step into the block that is (1 + num_speculative_blocks) blocks from the end of the request's allocation.

Source code in vllm/v1/worker/mamba_utils.py
def preprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    cache_config: CacheConfig,
    mamba_state_idx: dict[str, int],
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    copy_bufs: MambaCopyBuffers,
):
    """
    Move each request's previous-step mamba state into the block that holds
    this step's running state. That block is (1 + num_speculative_blocks)
    blocks from the end of the request's allocation.

    Copy metadata is accumulated in ``copy_bufs`` and issued in one batch at
    the end. ``mamba_state_idx`` is updated to the new running-block index
    for every request in ``input_batch``.
    """
    group_ids = copy_bufs.mamba_group_ids
    spec = copy_bufs.mamba_spec
    spec_blocks = spec.num_speculative_blocks
    # TODO(Chen): we need to optimize this function a lot
    assert cache_config.enable_prefix_caching
    block_size = spec.block_size
    cleanup_mamba_state_idx(scheduler_output, mamba_state_idx)

    copy_bufs.offset = 0
    scheduled_per_req = scheduler_output.num_scheduled_tokens
    for row, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]
        computed = req_state.num_computed_tokens
        src_idx = mamba_state_idx.get(req_id)
        if src_idx is None:
            # New or resumed request with no recorded state: use the block of
            # the last computed token (-1 when nothing has been computed yet).
            src_idx = (computed - 1) // block_size

        # Allocation = blocks covering computed + scheduled tokens plus
        # spec_blocks speculative blocks. The running state is always saved
        # (1 + spec_blocks) blocks from the end, i.e. in the block that holds
        # the last scheduled token.
        # Corner case: block_size = 4, 2 speculative tokens, request [A, B, C]
        # with drafts [draft 1, draft 2]:
        #   Block 0: [A, B, C, draft 1]
        #   Block 1: [draft 2, TOFILL, TOFILL, TOFILL]
        #   Blocks 2-3: speculative blocks
        # Block 1 is the one that saves the running state.
        total_blocks: int = (
            cdiv(computed + scheduled_per_req[req_id], block_size) + spec_blocks
        )
        dst_idx = total_blocks - spec_blocks - 1
        mamba_state_idx[req_id] = dst_idx

        # Nothing to migrate for a fresh request or an unchanged block.
        if src_idx == -1 or src_idx == dst_idx:
            continue

        collect_mamba_copy_meta(
            copy_bufs,
            kv_cache_config,
            mamba_state_copy_funcs,
            group_ids,
            src_idx,
            dst_idx,
            input_batch.num_accepted_tokens_cpu[row] - 1,
            req_state,
            forward_context,
        )
        # The accepted count was consumed as the copy bias above; reset it
        # now that the state has been migrated to the new block.
        input_batch.num_accepted_tokens_cpu[row] = 1
    do_mamba_copy_block(copy_bufs)

preprocess_mamba_align_fused_kernel(idx_mapping_ptr, state_idx_ptr, num_computed_tokens_ptr, query_start_loc_ptr, num_accepted_tokens_ptr, src_col_ptr, src_off_ptr, num_reqs, BLOCK_SIZE, MAMBA_BLOCK_SIZE)

Fused align preprocess: emit the pre-copy src column/offset AND advance state_idx (with accepted-token reset) in a single launch (V2 align).

For each batch_idx in 0..num_reqs-1, the req slot is resolved via idx_mapping and the kernel then:

1. Reads the pre-advance state_idx and num_accepted (last step's values).
2. Stores the pre-copy src columns for precopy_mamba_align_fused_kernel:
   - src_col = state_idx (the previous running block column)
   - src_off = max(num_accepted - 1, 0) (the accepted-token bias)
3. Advances state_idx to the new running block, and resets num_accepted to 1 when a block boundary is crossed (so the migrated state, now at the start of the new block, is read with the neutral bias).

Source code in vllm/v1/worker/mamba_utils.py
@triton.jit
def preprocess_mamba_align_fused_kernel(
    idx_mapping_ptr,
    state_idx_ptr,
    num_computed_tokens_ptr,
    query_start_loc_ptr,
    num_accepted_tokens_ptr,
    src_col_ptr,
    src_off_ptr,
    num_reqs,
    BLOCK_SIZE: tl.constexpr,
    MAMBA_BLOCK_SIZE: tl.constexpr,
):
    """Fused align preprocess: emit the pre-copy src column/offset AND advance
    state_idx (with accepted-token reset) in a single launch (V2 align).

    Each program handles ``BLOCK_SIZE`` consecutive batch positions starting
    at ``program_id(0) * BLOCK_SIZE``; lanes whose batch position is
    ``>= num_reqs`` are masked out and perform no loads or stores.

    Per batch_idx (0..num_reqs-1), resolving req slot via idx_mapping:
      1. Read pre-advance state_idx and num_accepted (last step's values).
      2. Store the pre-copy src columns for ``precopy_mamba_align_fused_kernel``:
         - src_col = state_idx (the previous running block column)
         - src_off = max(num_accepted - 1, 0) (the accepted-token bias)
      3. Advance state_idx to the new running block,
         ``ceil((num_computed + query_len) / MAMBA_BLOCK_SIZE) - 1``, and reset
         num_accepted to 1 when a block boundary is crossed, i.e. when the
         previous state_idx was valid (``>= 0``) and differs from the new one
         (so the migrated state, now at the start of the new block, is read
         with the neutral bias).

    Indexing: ``idx_mapping_ptr`` and ``query_start_loc_ptr`` are indexed by
    batch position; ``state_idx_ptr``, ``num_computed_tokens_ptr``,
    ``num_accepted_tokens_ptr``, ``src_col_ptr`` and ``src_off_ptr`` are
    indexed by the request slot read from ``idx_mapping_ptr``.
    ``query_start_loc_ptr`` is read at ``batch_idx`` and ``batch_idx + 1``,
    so it must hold at least ``num_reqs + 1`` entries.

    NOTE(review): ``state_idx_ptr`` and ``num_accepted_tokens_ptr`` are read
    and then overwritten in place, so their loads must stay ahead of their
    stores. This presumably also requires ``idx_mapping`` to contain no
    duplicate slots within one launch. TODO confirm with the caller.
    """
    # Batch positions covered by this program; mask off the tail past num_reqs.
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_reqs
    # Map batch position -> persistent request slot.
    req_indices = tl.load(idx_mapping_ptr + offsets, mask=mask, other=0)

    # Last step's running block and accepted-token count for each slot
    # (read BEFORE they are overwritten below).
    state_idx = tl.load(state_idx_ptr + req_indices, mask=mask, other=-1)
    num_accepted = tl.load(num_accepted_tokens_ptr + req_indices, mask=mask, other=1)

    # Record where the pre-copy should read from. max() clamps the bias at 0,
    # presumably to guard a num_accepted of 0. TODO confirm.
    src_off = tl.maximum(num_accepted - 1, 0)
    tl.store(src_col_ptr + req_indices, state_idx, mask=mask)
    tl.store(src_off_ptr + req_indices, src_off, mask=mask)

    # Tokens computed once this step's query (query_end - query_start) lands.
    num_computed = tl.load(num_computed_tokens_ptr + req_indices, mask=mask, other=0)
    query_start = tl.load(query_start_loc_ptr + offsets, mask=mask, other=0)
    query_end = tl.load(query_start_loc_ptr + offsets + 1, mask=mask, other=0)
    computed_after = num_computed + query_end - query_start
    # 0-based index of the block holding the last token:
    # ceil(computed_after / MAMBA_BLOCK_SIZE) - 1, which is -1 when
    # computed_after == 0.
    new_state_idx = (computed_after + MAMBA_BLOCK_SIZE - 1) // MAMBA_BLOCK_SIZE - 1
    tl.store(state_idx_ptr + req_indices, new_state_idx, mask=mask)
    # Reset only when a valid previous block existed and we moved off it.
    # `state_idx` here is still the pre-advance value loaded above.
    should_reset = (state_idx >= 0) & (state_idx != new_state_idx)
    tl.store(num_accepted_tokens_ptr + req_indices, 1, mask=mask & should_reset)

stage_mamba_state_idx_to_gpu(mamba_state_idx, req_ids, num_reqs, gpu_buf)

Materialize mamba_state_idx into gpu_buf and copy to GPU.

Walks req_ids[:num_reqs] in batch order, writing each request's block index into the buffer's pinned numpy view, then issues a non-blocking H→D copy. The fused kernel indexes the resulting GPU tensor by req_idx.

Invariant: preprocess_mamba must have run first for the same batch so that every req_ids[i] has an entry in mamba_state_idx.

Source code in vllm/v1/worker/mamba_utils.py
def stage_mamba_state_idx_to_gpu(
    mamba_state_idx: dict[str, int],
    req_ids: list[str],
    num_reqs: int,
    gpu_buf: CpuGpuBuffer,
) -> None:
    """Copy the running mamba block index of each batched request to the GPU.

    For every batch position ``pos`` in ``0 .. num_reqs - 1`` the block index
    of ``req_ids[pos]`` is looked up in ``mamba_state_idx`` and written into
    ``gpu_buf.np[pos]`` (the pinned host view). Afterwards the first
    ``num_reqs`` entries are pushed with ``gpu_buf.copy_to_gpu(num_reqs)``,
    a non-blocking H→D copy. The fused kernel indexes the resulting device
    tensor by ``req_idx``.

    Precondition: ``preprocess_mamba`` has already run for this batch, so
    every scheduled request has a ``mamba_state_idx`` entry. A missing entry
    raises ``AssertionError`` when assertions are enabled.
    """
    host_view = gpu_buf.np
    for pos in range(num_reqs):
        rid = req_ids[pos]
        block_idx = mamba_state_idx.get(rid)
        # Internal invariant rather than user input, hence `assert`.
        assert block_idx is not None, (
            f"mamba_state_idx missing entry for {rid!r}; "
            "preprocess_mamba must run before stage_mamba_state_idx_to_gpu"
        )
        host_view[pos] = block_idx
    gpu_buf.copy_to_gpu(num_reqs)

stage_postprocess_inputs_to_gpu(ctx, scheduler_output, req_ids, num_reqs, requests, mamba_state_idx)

Stage all per-request inputs the fused mamba postprocess kernel reads.

Bundles stage_mamba_state_idx_to_gpu and stage_postprocess_metadata_to_gpu into a single call so the runner has one entry point for postprocess staging. Buffers live on ctx and only exist when the postprocess kernel is enabled.

Source code in vllm/v1/worker/mamba_utils.py
def stage_postprocess_inputs_to_gpu(
    ctx: MambaSpecDecodeGPUContext,
    scheduler_output: SchedulerOutput,
    req_ids: list[str],
    num_reqs: int,
    requests: dict[str, CachedRequestState],
    mamba_state_idx: dict[str, int],
) -> None:
    """Single entry point that stages every input of the fused postprocess kernel.

    Takes the four staging buffers from ``ctx`` and checks they are not
    ``None``; they are only allocated when the postprocess kernel is enabled.
    It then:

    * writes each request's running-block index via
      ``stage_mamba_state_idx_to_gpu``; and
    * writes the scheduled, computed and draft token counts via
      ``stage_postprocess_metadata_to_gpu``.
    """
    state_idx_buf = ctx.mamba_state_idx_buf
    assert state_idx_buf is not None
    scheduled_buf = ctx.num_scheduled_tokens_buf
    assert scheduled_buf is not None
    computed_buf = ctx.num_computed_tokens_buf
    assert computed_buf is not None
    draft_buf = ctx.num_draft_tokens_buf
    assert draft_buf is not None

    stage_mamba_state_idx_to_gpu(
        mamba_state_idx=mamba_state_idx,
        req_ids=req_ids,
        num_reqs=num_reqs,
        gpu_buf=state_idx_buf,
    )
    stage_postprocess_metadata_to_gpu(
        scheduler_output=scheduler_output,
        req_ids=req_ids,
        num_reqs=num_reqs,
        requests=requests,
        num_scheduled_tokens_buf=scheduled_buf,
        num_computed_tokens_buf=computed_buf,
        num_draft_tokens_buf=draft_buf,
    )

stage_postprocess_metadata_to_gpu(scheduler_output, req_ids, num_reqs, requests, num_scheduled_tokens_buf, num_computed_tokens_buf, num_draft_tokens_buf)

Stage per-request postprocess metadata into GPU buffers (non-blocking).

Walks req_ids[:num_reqs] in batch order and writes each request's scheduled/computed/draft token counts into the matching pinned numpy views, then issues three non-blocking H→D copies. These values don't change between _prepare_inputs and _update_states_after_model_execute. The fused postprocess kernel indexes the resulting GPU tensors by req_idx.

Source code in vllm/v1/worker/mamba_utils.py
def stage_postprocess_metadata_to_gpu(
    scheduler_output: SchedulerOutput,
    req_ids: list[str],
    num_reqs: int,
    requests: dict[str, CachedRequestState],
    num_scheduled_tokens_buf: CpuGpuBuffer,
    num_computed_tokens_buf: CpuGpuBuffer,
    num_draft_tokens_buf: CpuGpuBuffer,
) -> None:
    """Fill the per-request token-count buffers and push them to the GPU.

    For each batch position ``pos`` in ``0 .. num_reqs - 1`` (request
    ``rid = req_ids[pos]``), this writes into the pinned ``.np`` view of:

    * ``num_scheduled_tokens_buf``: the value of
      ``scheduler_output.num_scheduled_tokens[rid]``;
    * ``num_computed_tokens_buf``: ``requests[rid].num_computed_tokens``;
    * ``num_draft_tokens_buf``: the number of speculative tokens scheduled
      for ``rid`` (0 when ``scheduled_spec_decode_tokens`` has no entry).

    It then calls ``copy_to_gpu(num_reqs)`` on the three buffers in that
    order; these are non-blocking H→D copies. The counts don't change between
    ``_prepare_inputs`` and ``_update_states_after_model_execute``. The fused
    postprocess kernel indexes the device tensors by ``req_idx``.
    """
    spec_tokens_by_req = scheduler_output.scheduled_spec_decode_tokens
    scheduled_by_req = scheduler_output.num_scheduled_tokens
    scheduled_host = num_scheduled_tokens_buf.np
    computed_host = num_computed_tokens_buf.np
    draft_host = num_draft_tokens_buf.np

    for pos in range(num_reqs):
        rid = req_ids[pos]
        scheduled_host[pos] = scheduled_by_req[rid]
        computed_host[pos] = requests[rid].num_computed_tokens
        draft_host[pos] = len(spec_tokens_by_req.get(rid, []))

    # Same order as the fields above: scheduled, computed, draft.
    for buf in (
        num_scheduled_tokens_buf,
        num_computed_tokens_buf,
        num_draft_tokens_buf,
    ):
        buf.copy_to_gpu(num_reqs)