class EncoderDecoderModelState(ModelState):
"""ModelState for cross-attention encoder-decoder models
(Whisper, CohereASR, NemotronParse, FireRedLID, ...)
"""
def __init__(
self,
vllm_config: VllmConfig,
model: nn.Module,
encoder_cache: EncoderCache | None,
device: torch.device,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.scheduler_config = vllm_config.scheduler_config
self.model = model
self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_model_len = self.model_config.max_model_len
self.device = device
assert encoder_cache is not None
self.encoder_cache = encoder_cache
self.encoder_runner = EncoderRunner(
model=self.model,
max_num_tokens=self.max_num_tokens,
hidden_size=self.model_config.get_inputs_embeds_size(),
encoder_cache=self.encoder_cache,
dtype=self.model_config.dtype,
device=self.device,
)
self.max_encoder_len = getattr(
self.model_config.hf_config,
"max_source_positions",
self.max_model_len,
)
self.encoder_seq_lens_gpu = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=self.device
)
self.encoder_outputs: list[torch.Tensor] = []
def get_mm_embeddings(
self, scheduled_encoder_inputs: dict[str, list[int]], input_batch: InputBatch
) -> None:
# Ensure encoder inputs are ordered consistently with input_batch.req_ids.
encoder_inputs: dict[str, list[int]] = {}
for req_id in input_batch.req_ids:
req_encoder_inputs = scheduled_encoder_inputs.get(req_id, [])
if req_encoder_inputs:
encoder_inputs[req_id] = req_encoder_inputs
_, mm_kwargs = self.encoder_runner.prepare_mm_inputs(encoder_inputs)
if mm_kwargs:
# Encoder-decoder models consume encoder outputs through the
# `encoder_outputs` forward kwarg, not `inputs_embeds`. Single modality
# so execute_mm_encoder preserves request order; use its return value
# directly. No need to store in encoder_cache: cross-attention K/V are
# written to the KV cache on the first step; decode steps use the cache.
self.encoder_outputs = self.encoder_runner.execute_mm_encoder(mm_kwargs)
else:
# Decode steps: encoder K/V are in cross-attention KV cache.
self.encoder_outputs = []
return None
def prepare_inputs(
self, input_batch: InputBatch, req_states: RequestState
) -> dict[str, Any]:
model_inputs = {"encoder_outputs": self.encoder_outputs}
self.encoder_outputs = []
return model_inputs
def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
return {"encoder_outputs": []}
def prepare_attn(
self,
input_batch: InputBatch,
cudagraph_mode: CUDAGraphMode,
block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
for_capture: bool = False,
) -> dict[str, Any]:
if cudagraph_mode == CUDAGraphMode.FULL:
num_reqs = input_batch.num_reqs_after_padding
num_tokens = input_batch.num_tokens_after_padding
else:
num_reqs = input_batch.num_reqs
num_tokens = input_batch.num_tokens
enc_dec_attn_metadata = EncoderDecoderAttnMetadata(
self._get_encoder_seq_lens(
input_batch.req_ids, attn_groups, for_capture, num_reqs
)
)
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item()
seq_lens_cpu_upper_bound = input_batch.seq_lens_cpu_upper_bound
if for_capture:
max_seq_len = self.max_model_len
else:
max_seq_len = int(seq_lens_cpu_upper_bound[:num_reqs].max().item())
attn_metadata = build_attn_metadata(
attn_groups=attn_groups,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=input_batch.seq_lens,
max_seq_len=max_seq_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
model_specific_attn_metadata=enc_dec_attn_metadata,
for_cudagraph_capture=for_capture,
)
return attn_metadata
def _get_encoder_seq_lens(
self,
req_ids: list[str],
attn_groups: list[list[AttentionGroup]],
for_capture: bool,
num_reqs: int,
) -> dict[int, tuple[torch.Tensor, np.ndarray]]:
encoder_seq_lens = torch.zeros(num_reqs, dtype=torch.int32, pin_memory=True)
encoder_seq_lens_np = encoder_seq_lens.numpy()
if not for_capture:
# During normal execution, use actual encoder lengths.
for i, req_id in enumerate(req_ids):
mm_features = self.encoder_cache.mm_features.get(req_id, [])
encoder_seq_lens_np[i] = sum(
feature.mm_position.get_num_embeds() for feature in mm_features
)
else:
# During CUDA graph capture, use max encoder length so max_seqlen_k
# is captured with the correct value for cross-attention.
encoder_seq_lens_np[:] = self.max_encoder_len
self.encoder_seq_lens_gpu[:num_reqs].copy_(encoder_seq_lens, non_blocking=True)
self.encoder_seq_lens_gpu[num_reqs:].fill_(0)
encoder_seq_lens_gpu = self.encoder_seq_lens_gpu[:num_reqs]
seq_lens_by_group: dict[int, tuple[torch.Tensor, np.ndarray]] = {}
for kv_cache_group_idx, groups in enumerate(attn_groups):
has_cross_attn = any(
isinstance(attn_group.kv_cache_spec, CrossAttentionSpec)
for attn_group in groups
)
if has_cross_attn:
seq_lens_by_group[kv_cache_group_idx] = (
encoder_seq_lens_gpu,
encoder_seq_lens_np,
)
return seq_lens_by_group