class AutoRegressiveSpeculator(DraftModelSpeculator):
def __init__(self, vllm_config: VllmConfig, device: torch.device):
    """Allocate the persistent device buffers used by the draft model.

    Args:
        vllm_config: Engine configuration, forwarded to the base class.
        device: Device on which every buffer below is allocated.

    The CUDA graph managers start out unset; ``init_cudagraph_manager``
    creates them.
    """
    super().__init__(vllm_config, device)
    token_by_hidden = (self.max_num_tokens, self.hidden_size)

    # Draft-model input hidden states, one row per token slot.
    self.hidden_states = torch.zeros(
        token_by_hidden, dtype=self.dtype, device=device
    )
    # 0-d int64 device tensor holding the current draft step. It is updated
    # in place with fill_() rather than rebound, presumably so that replayed
    # CUDA graphs observe the new value -- TODO confirm.
    self.current_draft_step = torch.tensor(0, dtype=torch.int64, device=device)
    # Per-request index of the token whose hidden state is sampled in the
    # draft prefill.
    self.last_token_indices = torch.zeros(
        self.max_num_reqs, dtype=torch.int64, device=device
    )

    self.supports_mm_inputs = MULTIMODAL_REGISTRY.supports_multimodal_inputs(
        self.draft_model_config
    )
    # Scratch space for input embeddings; only multimodal draft models use it.
    if self.supports_mm_inputs:
        self.inputs_embeds = torch.zeros(
            token_by_hidden, dtype=self.dtype, device=device
        )

    self.prefill_cudagraph_manager: PrefillSpeculatorCudaGraphManager | None = None
    self.decode_cudagraph_manager: DecodeSpeculatorCudaGraphManager | None = None
@property
def advance_draft_positions(self) -> bool:
    """Report whether positions and seq_lens move forward between draft steps.

    This speculator answers True: every draft step produces new KV entries
    (Eagle, standard MTP). Gemma4 MTP answers False, because it issues
    queries only, reuses the target model's KV, and keeps positions constant.
    """
    advances = True
    return advances
@property
def model_returns_tuple(self) -> bool:
    """Report the shape of the draft model's forward() result.

    True means forward() yields a ``(last_hidden_states, hidden_states)``
    pair (Eagle, Gemma4 MTP). False means it yields one tensor that serves
    as both (standard MTP, e.g. DeepSeek). This speculator answers True.
    """
    returns_pair = True
    return returns_pair
def init_cudagraph_manager(self, cudagraph_mode: CUDAGraphMode) -> None:
    """Create the CUDA graph managers for draft prefill and draft decode.

    Args:
        cudagraph_mode: The configured CUDA graph mode. The prefill manager
            (draft position 0) receives it unchanged. The decode manager
            (draft positions > 0) receives FULL_DECODE_ONLY when the decode
            part of the mode is FULL, and NONE otherwise.
    """
    # The trailing argument is presumably the per-request query length of a
    # uniform draft prefill -- confirm against the manager's signature.
    self.prefill_cudagraph_manager = PrefillSpeculatorCudaGraphManager(
        self.vllm_config,
        self.device,
        cudagraph_mode,
        self.num_speculative_steps + 1,
    )

    # Draft decodes cannot use PIECEWISE graphs: either full graphs or eager.
    decode_cudagraph_mode = (
        CUDAGraphMode.FULL_DECODE_ONLY
        if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
        else CUDAGraphMode.NONE
    )
    self.decode_cudagraph_manager = DecodeSpeculatorCudaGraphManager(
        self.vllm_config,
        self.device,
        decode_cudagraph_mode,
        decode_query_len=1,
    )
def capture(
    self,
    attn_states: dict[BatchExecutionDescriptor, AttentionStatePair],
) -> None:
    """Record CUDA graphs for the draft prefill and, if needed, draft decode.

    The prefill routine is ``_prefill`` (model forward + compute_logits +
    sample). Under FULL graphs the whole routine becomes one graph. Under
    PIECEWISE, only the model's compiled regions are captured, and
    compute_logits / gumbel_sample still run eagerly. The decode routine
    is ``_generate_draft`` (one step of forward + sample +
    update_draft_inputs). It is captured only when more than one
    speculative step is configured.

    Args:
        attn_states: Attention state pairs keyed by batch descriptor,
            forwarded to the prefill manager's capture.
    """
    logger.info("Capturing model for speculator...")
    # Clear the gather indices so values left over from earlier dummy runs
    # cannot cause out-of-bounds indexing while graphs are recorded.
    self.last_token_indices.zero_()

    prefill_manager = self.prefill_cudagraph_manager
    assert prefill_manager is not None
    if prefill_manager.use_breakable_cg:
        prefill_manager.init_breakable_cg_runner(self.model)
    prefill_manager.capture(
        self._prefill,
        attn_states,
        progress_bar_desc="Capturing prefill CUDA graphs",
    )

    # A single speculative step has no decode phase to capture.
    if self.num_speculative_steps != 1:
        decode_manager = self.decode_cudagraph_manager
        assert decode_manager is not None
        decode_manager.capture(
            self._generate_draft,
            self.model_state,
            self.input_buffers,
            self.block_tables,
            self.attn_groups,
            self.kv_cache_config,
            progress_bar_desc="Capturing decode CUDA graphs",
        )
@torch.inference_mode()
def propose(
    self,
    input_batch: InputBatch,
    attn_metadata: dict[str, Any],
    slot_mappings: dict[str, torch.Tensor],
    # [num_tokens, hidden_size]
    last_hidden_states: torch.Tensor,
    # num_layers x [num_tokens, hidden_size]
    aux_hidden_states: list[torch.Tensor] | None,
    # [num_reqs]
    num_sampled: torch.Tensor,
    # [num_reqs]
    num_rejected: torch.Tensor,
    # [max_num_reqs]
    last_sampled: torch.Tensor,
    # [max_num_reqs]
    next_prefill_tokens: torch.Tensor,
    # [max_num_reqs]
    temperature: torch.Tensor,
    # [max_num_reqs]
    seeds: torch.Tensor,
    num_tokens_across_dp: torch.Tensor | None = None,
    dummy_run: bool = False,
    skip_attn_for_dummy_run: bool = False,
    mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
    is_profile: bool = False,
) -> torch.Tensor:
    """Propose draft tokens for every request in ``input_batch``.

    Drafting runs in two phases:

    1. A draft "prefill" over the same (padded) token layout as the target
       model's step. It samples the first draft token of each request into
       column 0 of ``self.draft_tokens``.
    2. If ``num_speculative_steps > 1``, ``num_speculative_steps - 1``
       single-token decode steps that fill the remaining columns.

    For each phase, ``dispatch_cg_and_sync_dp`` decides whether to replay a
    FULL CUDA graph or to call ``_prefill`` / ``_multi_step_decode``.

    Args:
        input_batch: The target model's batch for this step.
        attn_metadata: Target-model attention metadata. It is reused as-is
            for the eager/PIECEWISE draft prefill.
        slot_mappings: Target-model slot mappings. They are reused as-is for
            the eager/PIECEWISE draft prefill.
        last_hidden_states: Target-model final hidden states.
        aux_hidden_states: Intermediate target hidden states (EAGLE-3 only).
        num_sampled, num_rejected, last_sampled, next_prefill_tokens:
            Forwarded to ``prepare_prefill_inputs``. ``num_rejected`` is
            also forwarded to ``prepare_decode_inputs``.
        temperature, seeds: Per-request sampling parameters, forwarded to
            ``_copy_request_inputs``.
        num_tokens_across_dp: Never read. It is overwritten by the result of
            ``dispatch_cg_and_sync_dp`` before its first use.
        dummy_run, skip_attn_for_dummy_run: Decode-step slot mappings and
            attention metadata are skipped only when both are True.
        mm_inputs: Multimodal inputs. They are forwarded to ``_prefill``
            only when the prefill is not a FULL graph replay.
        is_profile: Passed as ``need_eager`` to both CUDA graph dispatches.

    Returns:
        ``self.draft_tokens[:num_reqs, :1]`` when ``num_speculative_steps``
        is 1, otherwise ``self.draft_tokens[:num_reqs]``. Both are views of
        a persistent buffer, so callers that need the values to outlive
        the next call presumably must copy them -- confirm.
    """
    num_tokens = input_batch.num_tokens_after_padding
    num_reqs = input_batch.num_reqs
    max_query_len = input_batch.num_scheduled_tokens.max()
    max_seq_len = input_batch.seq_lens_cpu_upper_bound[:num_reqs].max().item()
    # Longest sequence length reachable while drafting, capped at the
    # model's maximum length.
    self.draft_max_seq_len = min(
        max_seq_len + self.num_speculative_steps, self.max_model_len
    )

    # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
    # number of rejected tokens, we maintain the size of input_ids and
    # hidden_states the same as the target model's. This means, we pad each
    # request's query length to include any rejected positions. By doing so,
    # we can also reuse the attention metadata (e.g., query_start_loc,
    # seq_lens) of the target model.
    if aux_hidden_states:
        # Only EAGLE-3 supplies auxiliary hidden states. They are concatenated
        # and merged by the draft model's combine_hidden_states.
        assert self.method == "eagle3"
        hidden_states = self.model.combine_hidden_states(
            torch.cat(aux_hidden_states, dim=-1)
        )
    else:
        hidden_states = last_hidden_states
    # Stage the draft model's input hidden states in the persistent buffer.
    self.hidden_states[:num_tokens].copy_(hidden_states)
    # NOTE(review): presumably stages idx_mapping / temperature / seeds into
    # self.idx_mapping, self.temperature and self.seeds, which _prefill and
    # _generate_draft read (defined in the base class) -- confirm.
    self._copy_request_inputs(
        num_reqs,
        input_batch.idx_mapping,
        temperature,
        seeds,
    )

    # Get the input ids and last token indices for the speculator.
    prepare_prefill_inputs(
        self.last_token_indices,
        self.current_draft_step,
        self.input_buffers,
        input_batch,
        num_sampled,
        num_rejected,
        last_sampled,
        next_prefill_tokens,
        self.max_num_reqs,
    )

    # When all requests are decoding (no true prefills), each has
    # num_speculative_steps + 1 tokens, enabling FULL graph replay.
    uniform_token_count = get_uniform_token_count(
        num_reqs,
        # Use the actual number of tokens without padding added by
        # the target model during FULL cudagraph.
        input_batch.num_tokens,
        max_query_len,
    )
    prefill_batch_desc, num_tokens_across_dp = dispatch_cg_and_sync_dp(
        self.prefill_cudagraph_manager,
        num_reqs,
        num_tokens,
        uniform_token_count,
        dp_size=self.dp_size,
        dp_rank=self.dp_rank,
        need_eager=is_profile,
    )
    if prefill_batch_desc.cg_mode == CUDAGraphMode.FULL:
        # Replay the full graph for draft prefill.
        # NOTE(review): mm_inputs is not used on this path. This is presumably
        # safe because FULL replay is meant for uniform decode-only batches
        # (see the comment above) -- confirm.
        assert self.prefill_cudagraph_manager is not None
        self.prefill_cudagraph_manager.run_fullgraph(prefill_batch_desc)
    else:
        # The target model's attention metadata and slot mappings
        # can directly be used for draft prefill, because of the
        # identical batch shape and KV cache layout.
        self._prefill(
            num_reqs,
            prefill_batch_desc.num_tokens,
            attn_metadata,
            slot_mappings,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=prefill_batch_desc.cg_mode,
            mm_inputs=mm_inputs,
        )

    if self.num_speculative_steps == 1:
        # Early exit.
        return self.draft_tokens[:num_reqs, :1]

    # Prepare the inputs for the decode steps.
    prepare_decode_inputs(
        self.draft_tokens[:num_reqs, 0],
        input_batch.seq_lens,
        num_rejected,
        self.input_buffers,
        self.max_model_len,
        self.max_num_reqs,
        advance_draft_positions=self.advance_draft_positions,
    )

    # Each request produces exactly 1 token per draft generation step,
    # enabling FULL graph replay.
    decode_batch_desc, num_tokens_across_dp = dispatch_cg_and_sync_dp(
        self.decode_cudagraph_manager,
        num_reqs,
        num_reqs,
        uniform_token_count=1,
        dp_size=self.dp_size,
        dp_rank=self.dp_rank,
        need_eager=is_profile,
    )

    # Generate the remaining num_speculative_steps - 1 draft tokens.
    self._multi_step_decode(
        num_reqs,
        dummy_run and skip_attn_for_dummy_run,
        decode_batch_desc,
        num_tokens_across_dp,
    )
    return self.draft_tokens[:num_reqs]
def sample_draft(
    self,
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
    idx_mapping: torch.Tensor,
    temperature: torch.Tensor,
    seeds: torch.Tensor,
    draft_step: torch.Tensor,
    draft_logits: torch.Tensor | None,
) -> torch.Tensor:
    """Turn draft hidden states into one token id per row.

    Logits come from ``self.model.compute_logits``. With no ``draft_logits``
    buffer, drafting is greedy (argmax over the last dimension). Otherwise
    tokens are drawn with ``gumbel_sample``, with temperature applied.
    ``draft_logits`` and ``draft_step`` are passed to it as
    ``output_processed_logits`` / ``output_processed_logits_col``.
    """
    logits = self.model.compute_logits(hidden_states)
    if draft_logits is None:
        return logits.argmax(dim=-1)

    # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
    # used for draft and target sampling.
    noise_positions = positions + 1
    return gumbel_sample(
        logits,
        idx_mapping,
        temperature,
        seeds,
        noise_positions,
        apply_temperature=True,
        output_processed_logits=draft_logits,
        output_processed_logits_col=draft_step,
        use_fp64=self.use_fp64_gumbel,
    )
@torch.inference_mode()
def _run_model(
    self,
    num_tokens: int,
    attn_metadata: dict[str, Any] | None,
    slot_mappings: dict[str, torch.Tensor] | None,
    num_tokens_across_dp: torch.Tensor | None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run one draft-model forward pass over the staged input buffers.

    The model reads the first ``num_tokens`` entries of
    ``input_buffers.input_ids``, ``input_buffers.positions`` and
    ``self.hidden_states``. For multimodal draft models it also reads
    freshly computed input embeddings. The pass runs inside a forward
    context built from the given metadata.

    Args:
        num_tokens: Number of (possibly padded) tokens to run.
        attn_metadata: Attention metadata for the forward context, or None.
        slot_mappings: KV-cache slot mappings for the forward context.
        num_tokens_across_dp: Token counts across data-parallel ranks.
        cudagraph_runtime_mode: PIECEWISE goes through the prefill manager's
            ``run_pw_graph``. Any other mode calls the model directly.
        mm_inputs: Optional ``(multimodal_embeddings, is_multimodal)`` pair.
            It is only consulted when the model supports multimodal inputs.

    Returns:
        ``(last_hidden_states, hidden_states)``. When the model's forward
        returns a single tensor, both elements are that tensor.
    """
    with set_forward_context(
        attn_metadata,
        self.vllm_config,
        num_tokens=num_tokens,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        num_tokens_across_dp=num_tokens_across_dp,
        slot_mapping=slot_mappings,
        batch_descriptor=BatchDescriptor(num_tokens=num_tokens),
    ):
        embeds = None
        if self.supports_mm_inputs:
            # Merge multimodal embeddings with the input ids. When a
            # multimodal mask is given, its length sets how many rows are
            # embedded.
            mm_embeds, is_mm_embed = mm_inputs or (None, None)
            embed_len = num_tokens if is_mm_embed is None else is_mm_embed.shape[0]
            self.inputs_embeds[:embed_len] = self.model.embed_input_ids(
                self.input_buffers.input_ids[:embed_len],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )
            embeds = self.inputs_embeds[:num_tokens]

        model_inputs = {
            "input_ids": self.input_buffers.input_ids[:num_tokens],
            "positions": self.input_buffers.positions[:num_tokens],
            "hidden_states": self.hidden_states[:num_tokens],
            "inputs_embeds": embeds,
        }
        if cudagraph_runtime_mode != CUDAGraphMode.PIECEWISE:
            # Call the raw model directly.
            output = self.model(**model_inputs)
        else:
            # Draft prefill under a PIECEWISE cudagraph. run_pw_graph chooses
            # between compiled piecewise and breakable graphs.
            assert self.prefill_cudagraph_manager is not None
            output = self.prefill_cudagraph_manager.run_pw_graph(
                self.model, model_inputs
            )

    if not self.model_returns_tuple:
        return output, output
    last_hidden_states, hidden_states = output
    return last_hidden_states, hidden_states
def _prefill(
    self,
    num_reqs: int,
    num_tokens: int,
    attn_metadata: dict[str, Any] | None,
    slot_mappings: dict[str, torch.Tensor] | None,
    num_tokens_across_dp: torch.Tensor | None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
) -> None:
    """Run the draft prefill and sample the first draft token per request.

    The model runs over all ``num_tokens`` staged tokens. Each request's
    row at ``self.last_token_indices`` is sampled into column 0 of
    ``self.draft_tokens``. The first ``num_reqs`` rows of
    ``self.hidden_states`` and ``self.input_buffers.positions`` are then
    overwritten with each request's last-token hidden state and position,
    which compacts them to the front of the buffers for the decode steps.
    """
    gather_idx = self.last_token_indices[:num_reqs]
    # Tensor-indexed gather returns a copy, so these values survive the
    # positions-buffer overwrite at the end of this method.
    last_positions = self.input_buffers.positions[gather_idx]

    last_hidden_states, hidden_states = self._run_model(
        num_tokens,
        attn_metadata,
        slot_mappings,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        mm_inputs=mm_inputs,
    )

    self.draft_tokens[:num_reqs, 0] = self.sample_draft(
        last_hidden_states[gather_idx],
        last_positions,
        self.idx_mapping[:num_reqs],
        self.temperature,
        self.seeds,
        self.current_draft_step,
        self.draft_logits,
    )
    self.hidden_states[:num_reqs] = hidden_states[gather_idx]
    self.input_buffers.positions[:num_reqs] = last_positions
def _multi_step_decode(
    self,
    num_reqs: int,
    skip_attn: bool,
    batch_desc: BatchExecutionDescriptor,
    num_tokens_across_dp: torch.Tensor | None,
) -> None:
    """Generate draft tokens for steps 1 through num_speculative_steps - 1.

    Each iteration may refresh the slot mappings and attention metadata.
    It then writes the step index into ``self.current_draft_step`` and
    either replays the FULL decode graph or calls ``_generate_draft``.

    Args:
        num_reqs: Number of live (unpadded) requests.
        skip_attn: When True, slot mappings and attention metadata are never
            built, so the eager path receives None for both.
        batch_desc: Decode-phase dispatch result (padded sizes, graph mode).
        num_tokens_across_dp: Forwarded to ``_generate_draft``.
    """
    idx_mapping = self.idx_mapping[:num_reqs]
    query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
    # A view of the live buffer, so position updates made by the previous
    # step are visible when slot mappings are recomputed.
    positions = self.input_buffers.positions[:num_reqs]

    attn_metadata = None
    slot_mappings_by_layer = None
    for step in range(1, self.num_speculative_steps):
        # Rebuild every step when positions advance. When positions are
        # constant (Gemma4 MTP), build only once on the first step.
        if not skip_attn and (self.advance_draft_positions or step == 1):
            slot_mappings_by_layer = build_slot_mappings_by_layer(
                self.block_tables.compute_slot_mappings(
                    idx_mapping,
                    query_start_loc,
                    positions,
                    batch_desc.num_tokens,
                ),
                self.kv_cache_config,
            )
            attn_metadata = self._build_draft_attn_metadata(
                num_reqs=num_reqs,
                num_reqs_padded=batch_desc.num_reqs or num_reqs,
                num_tokens_padded=batch_desc.num_tokens,
            )

        self.current_draft_step.fill_(step)

        if batch_desc.cg_mode != CUDAGraphMode.FULL:
            self._generate_draft(
                num_reqs,
                batch_desc.num_tokens,
                attn_metadata,
                slot_mappings_by_layer,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=batch_desc.cg_mode,
            )
            continue

        assert self.decode_cudagraph_manager is not None
        self.decode_cudagraph_manager.run_fullgraph(batch_desc)
def _generate_draft(
    self,
    num_reqs: int,
    num_tokens_padded: int,
    attn_metadata: dict[str, Any] | None,
    slot_mappings: dict[str, torch.Tensor] | None,
    num_tokens_across_dp: torch.Tensor | None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
) -> None:
    """Run a single draft decode step and stage inputs for the next one.

    The model runs over ``num_tokens_padded`` tokens. Only the first
    ``num_reqs`` rows of its last hidden states are sampled. The sampled
    tokens, the full (padded) hidden-state output and the step counter are
    then passed to ``update_draft_inputs``.
    """
    last_hidden_states, hidden_states = self._run_model(
        num_tokens_padded,
        attn_metadata,
        slot_mappings,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
    )

    # Padding rows beyond num_reqs are dropped before sampling.
    step_tokens = self.sample_draft(
        last_hidden_states[:num_reqs],
        self.input_buffers.positions[:num_reqs],
        self.idx_mapping[:num_reqs],
        self.temperature,
        self.seeds,
        self.current_draft_step,
        self.draft_logits,
    )

    update_draft_inputs(
        step_tokens,
        self.current_draft_step,
        hidden_states,
        self.draft_tokens,
        self.hidden_states,
        self.input_buffers,
        num_reqs,
        self.max_model_len,
        self.num_speculative_steps,
        advance_draft_positions=self.advance_draft_positions,
    )