class DefaultModelState(ModelState):
    """Model state that adds optional rope-position and multimodal-pruning support.

    On construction it asks ``get_rope_state`` for a rope state (which may be
    ``None``) and ``maybe_create_mm_pruner`` for a pruner (which may also be
    ``None``).  Every method below degrades to the plain path when the
    corresponding helper is absent.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ):
        """Initialise the base state, then build the rope state and pruner.

        Args:
            vllm_config: Forwarded unchanged to the parent constructor.
            model: The model; also handed to the rope-state and pruner factories.
            encoder_cache: Forwarded to the parent and to the pruner factory.
            device: Forwarded to the parent constructor.
        """
        super().__init__(vllm_config, model, encoder_cache, device)

        # Sizes and device are read back from attributes that exist after the
        # parent constructor has run.
        self.rope_state = get_rope_state(
            self.model_config,
            model,
            max_num_reqs=self.max_num_reqs,
            max_num_tokens=self.max_num_tokens,
            max_model_len=self.max_model_len,
            device=self.device,
        )

        # The pruner handles multimodal embedding pruning (EVS).
        self.mm_pruner = maybe_create_mm_pruner(
            self.model_config, model, self.rope_state, encoder_cache
        )

    def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
        """Initialise rope prefill positions for a newly added request.

        Does nothing when there is no rope state.

        Args:
            req_index: Index of the request, passed through to the rope state.
            new_req_data: Supplies ``prefill_token_ids`` (must not be ``None``
                when a rope state exists) and ``mm_features``.
        """
        rope_state = self.rope_state
        if rope_state is None:
            return

        prefill_token_ids = new_req_data.prefill_token_ids
        assert prefill_token_ids is not None
        rope_state.init_prefill_positions(
            req_index,
            self.model,
            prefill_token_ids,
            mm_features=new_req_data.mm_features,
        )

    def apply_staged_writes(self) -> None:
        """Forward to the rope state's ``apply_staged_writes``, if there is one."""
        rope_state = self.rope_state
        if rope_state is None:
            return
        rope_state.apply_staged_writes()

    def dummy_inputs_embeds(self, num_tokens: int) -> torch.Tensor:
        """Pre-allocated inputs_embeds buffer for dummy runs (contents unused)."""
        embeds_buffer = self.encoder_runner.inputs_embeds
        return embeds_buffer[:num_tokens]

    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> torch.Tensor:
        """Run the multimodal encoder if needed and build the input embeddings.

        Args:
            scheduled_encoder_inputs: Passed to the encoder runner's
                ``prepare_mm_inputs``.
            input_batch: Current batch; provides ``input_ids``, ``num_tokens``
                and ``num_tokens_after_padding``.
            req_states: Passed to the pruner's ``recompute`` when pruning is
                active.

        Returns:
            The tensor from ``get_inputs_embeds``, sliced to
            ``input_batch.num_tokens_after_padding``.
        """
        runner = self.encoder_runner
        mm_hashes, mm_kwargs = runner.prepare_mm_inputs(scheduled_encoder_inputs)

        if mm_kwargs:
            # Run the multimodal encoder and store each output under its mm_hash.
            encoder_outputs = runner.execute_mm_encoder(mm_kwargs)
            self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))

        # NOTE(review): this calls the parent implementation, so the
        # embeddings skip this class's strip() override before recompute().
        # Presumably recompute() needs the unstripped data -- confirm against
        # the pruner.
        mm_embeds, is_mm_embed = super().gather_mm_embeddings(input_batch)

        pruner = self.mm_pruner
        if pruner is not None and mm_embeds:
            # EVS: recompute mrope positions for pruned media.
            mm_embeds = pruner.recompute(mm_embeds, input_batch, req_states)
            # Flush the staged rope updates so prepare_inputs() picks them up.
            self.apply_staged_writes()

        # input_batch.input_ids may be padded for CUDA graphs; trim it to
        # num_tokens so its size matches is_mm_embed.
        unpadded_ids = input_batch.input_ids[: input_batch.num_tokens]
        merged = runner.get_inputs_embeds(unpadded_ids, mm_embeds, is_mm_embed)
        return merged[: input_batch.num_tokens_after_padding]

    def gather_mm_embeddings(
        self, input_batch: InputBatch, draft_lookahead: int = 0
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Gather multimodal embeddings via the parent, stripping them if pruning.

        Args:
            input_batch: Forwarded to the parent implementation.
            draft_lookahead: Forwarded to the parent implementation.

        Returns:
            ``(mm_embeds, is_mm_embed)`` from the parent, with ``mm_embeds``
            passed through the pruner's ``strip`` when a pruner exists.
        """
        mm_embeds, is_mm_embed = super().gather_mm_embeddings(
            input_batch, draft_lookahead
        )
        pruner = self.mm_pruner
        if pruner is None:
            return mm_embeds, is_mm_embed

        # EVS: strip the appended mrope-position channels.
        return pruner.strip(mm_embeds), is_mm_embed

    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, torch.Tensor | None]:
        """Build the extra model inputs for a real forward pass.

        Returns:
            An empty dict when there is no rope state; otherwise a dict with a
            single ``"positions"`` entry sized to
            ``input_batch.num_tokens_after_padding``.
        """
        rope_state = self.rope_state
        if rope_state is None:
            # Common case (1D positions): nothing extra to provide.
            return {}

        rope_state.prepare_positions(
            input_batch.idx_mapping,
            input_batch.query_start_loc,
            req_states.prefill_len.gpu,
            req_states.num_computed_tokens.gpu,
        )
        return {
            "positions": rope_state.get_positions(
                input_batch.num_tokens_after_padding
            )
        }

    def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
        """Build the extra model inputs for a dummy run.

        Args:
            num_reqs: Not used by this implementation.
            num_tokens: Number of tokens to size each entry for.

        Returns:
            A dict containing ``"inputs_embeds"`` when ``supports_mm_inputs`` is
            truthy and ``"positions"`` when a rope state exists.
        """
        dummy: dict[str, Any] = {}
        if self.supports_mm_inputs:
            dummy["inputs_embeds"] = self.encoder_runner.inputs_embeds[:num_tokens]
        if self.rope_state is not None:
            dummy["positions"] = self.rope_state.get_positions(num_tokens)
        return dummy

    def prepare_attn(
        self,
        input_batch: InputBatch,
        cudagraph_mode: CUDAGraphMode,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        for_capture: bool = False,
    ) -> dict[str, Any]:
        """Assemble attention metadata for the batch via ``build_attn_metadata``.

        Args:
            input_batch: Source of the sizes, query/sequence lengths and
                positions.
            cudagraph_mode: ``CUDAGraphMode.FULL`` selects padded sizes; any
                other mode selects unpadded sizes.
            block_tables: Forwarded to ``build_attn_metadata``.
            slot_mappings: Forwarded to ``build_attn_metadata``.
            attn_groups: Forwarded to ``build_attn_metadata``.
            kv_cache_config: Forwarded to ``build_attn_metadata``.
            for_capture: When true, ``max_seq_len`` is pinned to
                ``self.max_model_len``.

        Returns:
            Whatever ``build_attn_metadata`` returns.
        """
        # Full CUDA graphs use the padded sizes (padding itself is handled by
        # model_runner.prepare_attn); piecewise graphs and eager mode use the
        # unpadded ones.
        use_padded = cudagraph_mode == CUDAGraphMode.FULL
        num_reqs = (
            input_batch.num_reqs_after_padding if use_padded else input_batch.num_reqs
        )
        num_tokens = (
            input_batch.num_tokens_after_padding
            if use_padded
            else input_batch.num_tokens
        )

        query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
        max_query_len = input_batch.num_scheduled_tokens.max().item()
        seq_lens_upper = input_batch.seq_lens_cpu_upper_bound

        # Capture with the worst-case length so the graph is valid at any
        # replay; otherwise use the largest upper bound among the active
        # requests.
        max_seq_len = (
            self.max_model_len
            if for_capture
            else seq_lens_upper[:num_reqs].max().item()
        )

        return build_attn_metadata(
            attn_groups=attn_groups,
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            max_query_len=max_query_len,
            seq_lens=input_batch.seq_lens,
            max_seq_len=max_seq_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=kv_cache_config,
            seq_lens_cpu_upper_bound=seq_lens_upper,
            dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
            positions=input_batch.positions,
            for_cudagraph_capture=for_capture,
            rswa_prefix_lens=input_batch.prompt_lens,
        )