class MambaHybridModelState(DefaultModelState):
"""Model state for hybrid attention + Mamba / linear-attention models."""
def __init__(
self,
vllm_config: VllmConfig,
model: nn.Module,
encoder_cache: EncoderCache | None,
device: torch.device,
) -> None:
super().__init__(vllm_config, model, encoder_cache, device)
self.cache_config = vllm_config.cache_config
self.num_accepted_tokens_gpu = torch.ones(
self.max_num_reqs, dtype=torch.int32, device=self.device
)
# Pre-copy "align" prefix-cache state (V2). The migration of each
# request's mamba state across block boundaries runs as a fused GPU
# kernel reusing the postprocess copy machinery, so the per-step src
# columns and the running state_idx are kept GPU-resident.
self._align_mode = self.cache_config.mamba_cache_mode == "align"
if self._align_mode:
self._mamba_state_idx_gpu = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=self.device
)
self._mamba_src_col_gpu = torch.full(
(self.max_num_reqs,), -1, dtype=torch.int32, device=self.device
)
self._mamba_src_off_gpu = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=self.device
)
self._mamba_ctx: MambaSpecDecodeGPUContext | None = None
self._mamba_group_ids: list[int] = []
self._mamba_spec: MambaSpec | None = None
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
super().add_request(req_index, new_req_data)
if self._align_mode:
# Seed the running state block from the resumed/prefilled position.
self._mamba_state_idx_gpu[req_index] = (
new_req_data.num_computed_tokens - 1
) // self.cache_config.block_size
self.num_accepted_tokens_gpu[req_index] = 1
def _get_mamba_group_info(
self, kv_cache_config: KVCacheConfig
) -> tuple[list[int], MambaSpec]:
if self._mamba_spec is None:
group_ids: list[int] = []
specs: list[MambaSpec] = []
for i, group in enumerate(kv_cache_config.kv_cache_groups):
spec = group.kv_cache_spec
if isinstance(spec, MambaSpec):
group_ids.append(i)
specs.append(spec)
assert specs, "no mamba layers in the model"
assert all(specs[0] == s for s in specs)
self._mamba_group_ids = group_ids
self._mamba_spec = specs[0]
return self._mamba_group_ids, self._mamba_spec
def _ensure_align_ctx(
self,
kv_cache_config: KVCacheConfig,
mamba_group_ids: list[int],
block_tables: tuple[torch.Tensor, ...],
) -> MambaSpecDecodeGPUContext:
if self._mamba_ctx is None:
copy_funcs = self.model.get_mamba_state_copy_func()
# The fused copy kernels shift conv windows assuming the SD layout;
# the DS layout cannot express a >0 spec-decode shift as a single
# contiguous copy (mirrors get_conv_copy_spec's NotImplementedError).
if get_conv_copy_spec in copy_funcs and is_conv_state_dim_first():
assert self.vllm_config.speculative_config is None, (
"DS conv state layout does not support mamba align state "
"copies with speculative decoding"
)
self._mamba_ctx = MambaSpecDecodeGPUContext.create(
max_num_reqs=self.max_num_reqs,
kv_cache_config=kv_cache_config,
num_state_types=len(copy_funcs),
device=self.device,
make_buffer=lambda n, dtype: CpuGpuBuffer(
n, dtype=dtype, device=self.device
),
)
ctx = self._mamba_ctx
if not ctx.is_initialized:
forward_context = self.vllm_config.compilation_config.static_forward_context
# block_tables are batch-order slices of the persistent
# input_block_tables (stable data_ptr), so the metadata is captured
# once here and reused across steps.
ctx.initialize_from_forward_context(
kv_cache_config,
forward_context,
self.model.get_mamba_state_copy_func(),
[block_tables[gid] for gid in mamba_group_ids],
)
return ctx
def preprocess_state(
self,
input_batch: InputBatch,
block_tables: tuple[torch.Tensor, ...],
kv_cache_config: KVCacheConfig,
num_computed_tokens: torch.Tensor,
) -> None:
"""Migrate each request's mamba state across block boundaries before the
forward (V1 align semantics, done on GPU). Runs on real batches only
(dummy DP/profiling runs skip preprocess_state), and before
``prepare_attn`` gathers ``num_accepted_tokens``, so the boundary reset
is visible to the forward kernels.
"""
if not self._align_mode:
return
num_reqs = input_batch.num_reqs
if num_reqs == 0:
return
mamba_group_ids, mamba_spec = self._get_mamba_group_info(kv_cache_config)
ctx = self._ensure_align_ctx(kv_cache_config, mamba_group_ids, block_tables)
# The state-advance + pre-copy kernels run every step; they fast-exit per
# request when src_col < 0 or src_col == dst_col, so no copy happens on
# steps that don't cross a block boundary. (Skipping the launch entirely
# would need a V1-style async-D2H of the actual num_computed, since
# num_computed_tokens_np is an optimistic mirror under async scheduling;
# the launch cost is ~0.3% of TPOT, so the GPU fast-exit suffices.)
block = 256
grid = (triton.cdiv(num_reqs, block),)
preprocess_mamba_align_fused_kernel[grid](
input_batch.idx_mapping,
self._mamba_state_idx_gpu,
num_computed_tokens,
input_batch.query_start_loc,
self.num_accepted_tokens_gpu,
self._mamba_src_col_gpu,
self._mamba_src_off_gpu,
num_reqs,
BLOCK_SIZE=block,
MAMBA_BLOCK_SIZE=mamba_spec.block_size,
)
ctx.run_fused_precopy(
num_reqs,
self._mamba_state_idx_gpu,
self._mamba_src_col_gpu,
self._mamba_src_off_gpu,
input_batch.idx_mapping,
)
def prepare_attn(
self,
input_batch: InputBatch,
cudagraph_mode: CUDAGraphMode,
block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
for_capture: bool = False,
) -> dict[str, Any]:
if cudagraph_mode == CUDAGraphMode.FULL:
num_reqs = input_batch.num_reqs_after_padding
num_tokens = input_batch.num_tokens_after_padding
else:
num_reqs = input_batch.num_reqs
num_tokens = input_batch.num_tokens
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item()
seq_lens_cpu_upper_bound = input_batch.seq_lens_cpu_upper_bound
if for_capture:
# Capture with worst-case max_seq_len so the graph is valid at any replay.
max_seq_len = self.max_model_len
else:
max_seq_len = seq_lens_cpu_upper_bound[:num_reqs].max().item()
is_prefilling = torch.zeros(num_reqs, dtype=torch.bool, device="cpu")
is_prefilling[: input_batch.num_reqs] = torch.from_numpy(
input_batch.is_prefilling_np
)
# During CUDAGraph capture, num_decode_draft_tokens_cpu and num_accepted_tokens
# are created by attn_metadata_builder.build_for_cudagraph_capture, so we only
# compute them during actual (non-capture) forward execution.
num_accepted_tokens = None
num_decode_draft_tokens_cpu = None
if not for_capture:
num_accepted_tokens = self.num_accepted_tokens_gpu.new_ones(num_reqs)
num_accepted_tokens[: input_batch.num_reqs] = self.num_accepted_tokens_gpu[
input_batch.idx_mapping
]
# GDN uses >= 0 to select spec-decode rows, so non-decode rows
# need the -1 sentinel rather than a raw zero draft count.
num_decode_draft_tokens_np = np.full(num_reqs, -1, dtype=np.int32)
if input_batch.num_draft_tokens_per_req is not None:
has_draft_tokens = input_batch.num_draft_tokens_per_req > 0
spec_decode_mask = has_draft_tokens & ~input_batch.is_prefilling_np
num_decode_draft_tokens_np[: input_batch.num_reqs] = np.where(
spec_decode_mask, input_batch.num_draft_tokens_per_req, -1
)
num_decode_draft_tokens_cpu = torch.from_numpy(num_decode_draft_tokens_np)
mamba_attn_metadata = MambaHybridAttnMetadata(
is_prefilling=is_prefilling,
num_accepted_tokens=num_accepted_tokens,
num_decode_draft_tokens_cpu=num_decode_draft_tokens_cpu,
)
return build_attn_metadata(
attn_groups=attn_groups,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=input_batch.seq_lens,
max_seq_len=max_seq_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
model_specific_attn_metadata=mamba_attn_metadata,
for_cudagraph_capture=for_capture,
rswa_prefix_lens=input_batch.prompt_lens,
)
def postprocess_state(
self,
idx_mapping: torch.Tensor,
num_sampled: torch.Tensor | int,
num_computed_tokens: torch.Tensor | None = None,
) -> None:
# Chunked prefill does not sample a token, so num_sampled can be 0.
# Mamba treats num_accepted_tokens=1 as the neutral non-spec value.
if not isinstance(num_sampled, int):
# idx_mapping may contain -1 sentinels (filtered rows) under PP; the
# kernel skips them rather than scattering with a host-side gather.
n = idx_mapping.shape[0]
if n:
_scatter_num_accepted_kernel[(n,)](
idx_mapping, num_sampled, self.num_accepted_tokens_gpu
)
else:
# Fill with single value.
self.num_accepted_tokens_gpu.index_fill_(
0, idx_mapping, max(num_sampled, 1)
)
# Align: save the running state to the block-aligned position when
# spec-decode acceptance leaves the sequence non-block-aligned (mirrors
# the V1 align postprocess). num_computed_tokens already holds the
# post-step advanced count.
if (
self._align_mode
and num_computed_tokens is not None
and self._mamba_ctx is not None
):
num_reqs = idx_mapping.shape[0]
if num_reqs:
self._mamba_ctx.run_fused_postprocess_align(
num_reqs,
self.num_accepted_tokens_gpu,
self._mamba_state_idx_gpu,
num_computed_tokens,
idx_mapping,
)