vllm_omni.core.prefix_cache ¶
Utilities for Prefix Caching in Omni models.
OmniTensorPrefixCache ¶
Prefix cache for hidden states (model outputs) and model specific multimodal outputs.
This class implements prefix caching in a non-invasive way on top of vLLM by leveraging the same slot mappings that the vLLM scheduler uses for the KV Cache.
Conceptually, this maps vLLM's cache layout of shape (num_blocks, block_size) to 3D tensors of shape (num_blocks, block_size, feature_size).
Note that feature_size may vary across multimodal_outputs.
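The mapping described above can be illustrated with a small sketch. It assumes the usual paged-cache convention in which a flat slot index equals `block_id * block_size + offset`. Nested Python lists stand in for tensors, and `make_cache` and `write_slots` are hypothetical helpers, not part of vllm_omni.

```python
# Illustrative only: nested lists stand in for tensors, and the helper
# names are hypothetical (they are not the vllm_omni API).


def make_cache(num_blocks: int, block_size: int, feature_size: int) -> list:
    """Allocate a zeroed (num_blocks, block_size, feature_size) cache."""
    return [
        [[0.0] * feature_size for _ in range(block_size)]
        for _ in range(num_blocks)
    ]


def write_slots(cache: list, slot_mapping: list, features: list) -> None:
    """Write one feature row per token into the slot assigned to it."""
    block_size = len(cache[0])
    for slot, row in zip(slot_mapping, features):
        # Assumed convention: slot = block_id * block_size + offset.
        block_id, offset = divmod(slot, block_size)
        cache[block_id][offset] = list(row)


cache = make_cache(num_blocks=4, block_size=2, feature_size=3)
# Three tokens placed in slots 2 and 3 (block 1) and slot 6 (block 3).
write_slots(
    cache,
    slot_mapping=[2, 3, 6],
    features=[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]],
)
```

Because each cached output gets its own tensor, a second cache with a different `feature_size` could reuse the same `slot_mapping` unchanged.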
add_prefix_cached_new_req_id ¶
add_prefix_cached_new_req_id(req_id: str)
Adds a new request ID to the set of prefix cache hits in the current batch.
commit_deferred_mm_outputs ¶
commit_deferred_mm_outputs(
finished_req_ids: set[str] | list[str],
input_batch: InputBatch,
) -> None
Write deferred multimodal chunks into the CPU prefix cache.
This must run before finished requests are removed from input_batch, because the block table is needed to map logical token positions to KV cache slots.
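To show why the block table is required, here is a sketch of the position-to-slot lookup. It assumes a per-request block table that lists physical block IDs in logical order. `logical_to_slot` is a hypothetical helper, not a vllm_omni function.

```python
def logical_to_slot(position: int, block_table: list, block_size: int) -> int:
    """Map a request's logical token position to a flat cache slot."""
    # Which logical block the token falls in, and where inside that block.
    logical_block, offset = divmod(position, block_size)
    # The block table translates the logical block to a physical block.
    physical_block = block_table[logical_block]
    return physical_block * block_size + offset


# A request whose tokens live in physical blocks 7, 2 and 5 (block_size = 4).
block_table = [7, 2, 5]
slots = [logical_to_slot(pos, block_table, block_size=4) for pos in range(10)]
```

Once the request is removed from the batch, its `block_table` is no longer available, so the lookup above could not be performed.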
discard_deferred_mm_outputs ¶
discard_deferred_mm_outputs(req_id: str) -> None
Drop deferred chunks for requests that leave without a cache commit.
get_merged_hidden_states ¶
get_merged_hidden_states(
query_start_loc: Tensor,
input_batch: InputBatch,
hidden_states: Tensor,
num_scheduled_tokens: dict[str, int],
hidden_states_cpu: Tensor | None = None,
) -> dict[str, Tensor]
Get merged hidden states, optionally reusing pre-staged CPU states.
When provided, hidden_states_cpu follows the same contract as update_omni_tensor_prefix_cache: CPU, contiguous, same dtype and feature shape as hidden_states, and covering every scheduled-token span derived from query_start_loc and num_scheduled_tokens.
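The coverage requirement can be sketched as follows. This assumes that the request at batch index i occupies rows `query_start_loc[i]` up to `query_start_loc[i] + num_scheduled_tokens[req_id]` of the flattened batch. Plain lists stand in for tensors, and `scheduled_spans` and `covers_all_spans` are hypothetical names.

```python
def scheduled_spans(
    query_start_loc: list,
    req_ids: list,
    num_scheduled_tokens: dict,
) -> dict:
    """Return {req_id: (start, end)} row spans in the flattened batch."""
    spans = {}
    for i, req_id in enumerate(req_ids):
        start = query_start_loc[i]
        # Assumed layout: each request's scheduled tokens are contiguous.
        spans[req_id] = (start, start + num_scheduled_tokens[req_id])
    return spans


def covers_all_spans(num_rows: int, spans: dict) -> bool:
    """Check that a staged buffer with num_rows rows contains every span."""
    return all(0 <= start <= end <= num_rows for start, end in spans.values())


spans = scheduled_spans(
    query_start_loc=[0, 3, 8],
    req_ids=["req-a", "req-b"],
    num_scheduled_tokens={"req-a": 3, "req-b": 5},
)
ok = covers_all_spans(num_rows=8, spans=spans)
too_short = covers_all_spans(num_rows=6, spans=spans)
```

A real check would also verify device, contiguity, dtype, and feature shape, which this list-based sketch cannot express.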
get_merged_multimodal_states ¶
get_merged_multimodal_states(
query_start_loc: Tensor,
input_batch: InputBatch,
multimodal_outputs: dict,
num_scheduled_tokens: dict[str, int],
)
Get the merged multimodal states if hidden state prefix caching is enabled.
has_prefix_cached_new_req_ids ¶
has_prefix_cached_new_req_ids() -> bool
Return True when this step contains a newly scheduled prefix hit.
maybe_init_missing_mm_cache_keys ¶
Given the multimodal outputs produced by executing the model, dynamically determine which outputs are tensors whose size depends on sequence length and should therefore be cached, then initialize the corresponding cache tensors.
NOTE: This is done to avoid the need for explicit specification of cache keys for every model/stage and aligns with the current way that we slice the multimodal outputs based on the first dimension.
This is usually triggered by the first forward pass, so in practice the cache keys are determined during warmup.
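One way to picture this heuristic is the sketch below. It assumes that "depending on sequence length" means the first dimension equals the number of tokens in the forward pass. `FakeTensor` and `find_cacheable_keys` are hypothetical stand-ins, not vllm_omni names.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in that exposes only a shape, like a tensor."""

    shape: tuple


def find_cacheable_keys(multimodal_outputs: dict, num_tokens: int) -> dict:
    """Return {key: per-token feature shape} for sequence-length outputs."""
    cacheable = {}
    for key, value in multimodal_outputs.items():
        shape: Any = getattr(value, "shape", None)
        # Assumed criterion: first dimension matches the token count.
        if shape and shape[0] == num_tokens:
            cacheable[key] = tuple(shape[1:])
    return cacheable


outputs = {
    "audio_features": FakeTensor((6, 128)),  # one row per token
    "pooled": FakeTensor((1, 512)),          # independent of sequence length
    "metadata": {"sample_rate": 16000},      # not a tensor at all
}
keys = find_cacheable_keys(outputs, num_tokens=6)
```

Each entry in `keys` would then size one cache tensor of shape (num_blocks, block_size, *feature_shape).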
reset_prefix_cached_new_req_ids ¶
Clears the cache hit IDs to prepare for a new engine step.
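The add, query, and reset methods form a simple per-step lifecycle. The sketch below models it with a plain set. `PrefixHitTracker` is a hypothetical stand-in, and the real class may track more state.

```python
class PrefixHitTracker:
    """Hypothetical stand-in for the per-step prefix-hit bookkeeping."""

    def __init__(self) -> None:
        self._new_req_ids = set()

    def add(self, req_id: str) -> None:
        # Mirrors add_prefix_cached_new_req_id.
        self._new_req_ids.add(req_id)

    def has_any(self) -> bool:
        # Mirrors has_prefix_cached_new_req_ids.
        return bool(self._new_req_ids)

    def reset(self) -> None:
        # Mirrors reset_prefix_cached_new_req_ids.
        self._new_req_ids.clear()


tracker = PrefixHitTracker()
before = tracker.has_any()       # no hits recorded yet
tracker.add("req-42")            # a new request hit the prefix cache
during = tracker.has_any()
tracker.reset()                  # prepare for the next engine step
after = tracker.has_any()
```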
stage_deferred_mm_outputs ¶
stage_deferred_mm_outputs(
query_start_loc: Tensor,
input_batch: InputBatch,
multimodal_outputs: dict[str, Tensor] | None,
num_scheduled_tokens: dict[str, int],
deferred_mm_cache_keys: set[str],
) -> None
Keep GPU multimodal chunks until a request finishes.
The normal prefix-cache write path copies every cached multimodal tensor to CPU on every step. For model outputs that are only needed by future full-block prefix hits, we can keep detached GPU chunks and materialize the CPU cache once when the request completes.
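The stage, commit, and discard flow can be sketched as below. Lists stand in for GPU chunks, and a counter stands in for device-to-host copies. `DeferredChunks` is a hypothetical class that only illustrates the idea of one copy per finished request.

```python
class DeferredChunks:
    """Hypothetical stand-in: buffer chunks per request, commit once."""

    def __init__(self) -> None:
        self.pending = {}    # req_id -> list of staged chunks
        self.committed = {}  # req_id -> merged data in the "CPU cache"
        self.copies = 0      # simulated device-to-host copies

    def stage(self, req_id: str, chunk: list) -> None:
        # Called every step: keep the chunk, do not copy it yet.
        self.pending.setdefault(req_id, []).append(list(chunk))

    def commit(self, finished_req_ids) -> None:
        # Called when requests finish: one merged copy per request.
        for req_id in finished_req_ids:
            chunks = self.pending.pop(req_id, None)
            if chunks is None:
                continue
            self.committed[req_id] = [x for chunk in chunks for x in chunk]
            self.copies += 1

    def discard(self, req_id: str) -> None:
        # Called when a request leaves without a cache commit.
        self.pending.pop(req_id, None)


deferred = DeferredChunks()
deferred.stage("req-a", [1.0, 2.0])  # step 1
deferred.stage("req-a", [3.0])       # step 2
deferred.stage("req-b", [9.0])
deferred.commit(["req-a"])           # req-a finished
deferred.discard("req-b")            # req-b left without committing
```

In this toy model, `req-a` is staged over two steps but copied only once, which is the saving the deferred path is after.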
update_omni_tensor_prefix_cache ¶
update_omni_tensor_prefix_cache(
hidden_states: Tensor | None,
multimodal_outputs: dict[str, Tensor] | None,
num_tokens_unpadded: int,
slot_mapping: Tensor,
num_tokens_padded: int | None = None,
skip_mm_cache_keys: set[str] | None = None,
hidden_states_cpu: Tensor | None = None,
)
Updates the prefix cache with the provided hidden states and multimodal outputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_states | Tensor \| None | Hidden states tensor to cache (if any) | required |
| multimodal_outputs | dict[str, Tensor] \| None | Multimodal dict whose tensors may be cached | required |
| num_tokens_unpadded | int | Number of tokens without padding | required |
| slot_mapping | Tensor | Slot mapping for the input sequence | required |
| num_tokens_padded | int \| None | Total number of tokens including padding | None |
| skip_mm_cache_keys | set[str] \| None | Multimodal keys whose CPU cache writes are deferred | None |
| hidden_states_cpu | Tensor \| None | Optional pre-staged CPU view of hidden_states. When provided, it must be contiguous, live on CPU, match the feature shape of hidden_states, and cover num_tokens_unpadded. | None |