
vllm_omni.core.prefix_cache

Utilities for Prefix Caching in Omni models.

logger module-attribute

logger = init_logger(__name__)

OmniTensorPrefixCache

Prefix cache for hidden states (model outputs) and model-specific multimodal outputs.

This class implements prefix caching in a non-invasive way on top of vLLM by leveraging the same slot mappings that the vLLM scheduler uses for the KV Cache.

Conceptually, this extends vLLM's cache layout of shape (num_blocks, block_size) to 3D tensors of shape (num_blocks, block_size, feature_size), so each KV cache slot addresses one cached feature row.

Note that feature_size may vary across multimodal_outputs.
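Because the cache reuses the KV cache's slot ids, a flat slot can be split into a block index and an in-block offset. The sketch below illustrates that addressing under the usual vLLM convention slot = block_id * block_size + offset. Nested Python lists stand in for tensors, and make_cache, write_slots and read_slots are hypothetical helpers, not this module's API.

```python
from __future__ import annotations


def make_cache(num_blocks: int, block_size: int, feature_size: int) -> list[list[list[float]]]:
    """Allocate a zero-filled cache of shape (num_blocks, block_size, feature_size)."""
    return [[[0.0] * feature_size for _ in range(block_size)] for _ in range(num_blocks)]


def write_slots(cache, slot_mapping: list[int], rows: list[list[float]], block_size: int) -> None:
    """Write one feature row per token into the slot given by slot_mapping."""
    for slot, row in zip(slot_mapping, rows):
        block_id, offset = divmod(slot, block_size)
        cache[block_id][offset] = list(row)


def read_slots(cache, slot_mapping: list[int], block_size: int) -> list[list[float]]:
    """Gather the cached feature rows for the given slots, in order."""
    return [list(cache[slot // block_size][slot % block_size]) for slot in slot_mapping]


block_size = 4
cache = make_cache(num_blocks=3, block_size=block_size, feature_size=2)
# Three tokens that a KV-cache slot mapping placed at the start of block 2.
slots = [8, 9, 10]
write_slots(cache, slots, [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], block_size)
print(read_slots(cache, slots, block_size))  # [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
```

A multimodal output with a different feature_size would simply get its own cache of the same (num_blocks, block_size) layout.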

block_size instance-attribute

block_size = block_size

default_hidden_size instance-attribute

default_hidden_size = hidden_size

hidden_states_cache instance-attribute

hidden_states_cache = _get_cache_tensor(dtype=hs_dtype)

mm_cache_keys instance-attribute

mm_cache_keys = set()

mm_outputs_cache instance-attribute

mm_outputs_cache = {}

num_blocks instance-attribute

num_blocks = num_blocks

add_prefix_cached_new_req_id

add_prefix_cached_new_req_id(req_id: str)

Adds a new request ID to the set of prefix-cache hits in the current batch.

commit_deferred_mm_outputs

commit_deferred_mm_outputs(
    finished_req_ids: set[str] | list[str],
    input_batch: InputBatch,
) -> None

Write deferred multimodal chunks into the CPU prefix cache.

This must run before finished requests are removed from input_batch, because the block table is needed to map logical token positions to KV cache slots.
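The ordering requirement follows from how logical token positions become slots: each position is looked up in the request's block table. A minimal sketch of that lookup, assuming the standard paged layout in which entry i of a block table holds the physical block for logical block i; positions_to_slots is a hypothetical helper.

```python
from __future__ import annotations


def positions_to_slots(block_table: list[int], positions: list[int], block_size: int) -> list[int]:
    """Return the flat KV cache slot for each logical token position."""
    slots = []
    for pos in positions:
        logical_block, offset = divmod(pos, block_size)
        physical_block = block_table[logical_block]  # IndexError if the block is unmapped
        slots.append(physical_block * block_size + offset)
    return slots


# A request whose first two logical blocks live in physical blocks 7 and 2.
print(positions_to_slots([7, 2], [0, 3, 4, 5], block_size=4))  # [28, 31, 8, 9]
```

Once a finished request is removed from input_batch, its block table is no longer available, so this translation can no longer be performed.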

discard_deferred_mm_outputs

discard_deferred_mm_outputs(req_id: str) -> None

Drop deferred chunks for requests that leave without a cache commit.

get_merged_hidden_states

get_merged_hidden_states(
    query_start_loc: Tensor,
    input_batch: InputBatch,
    hidden_states: Tensor,
    num_scheduled_tokens: dict[str, int],
    hidden_states_cpu: Tensor | None = None,
) -> dict[str, Tensor]

Get merged hidden states, optionally reusing pre-staged CPU states.

When provided, hidden_states_cpu follows the same contract as update_omni_tensor_prefix_cache: CPU, contiguous, same dtype and feature shape as hidden_states, and covering every scheduled-token span derived from query_start_loc and num_scheduled_tokens.
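The coverage condition on hidden_states_cpu can be made concrete. This sketch assumes that query_start_loc holds each request's starting row in batch order and that req_ids lists the requests in that same order. scheduled_spans and covers_all_spans are hypothetical helpers, and plain lists stand in for tensors. The other conditions would correspond to tensor checks such as hidden_states_cpu.device.type == "cpu", hidden_states_cpu.is_contiguous() and a dtype comparison, which are not shown.

```python
from __future__ import annotations


def scheduled_spans(
    req_ids: list[str],
    query_start_loc: list[int],
    num_scheduled_tokens: dict[str, int],
) -> dict[str, tuple[int, int]]:
    """Map each request id (in batch order) to its [start, end) row range."""
    spans = {}
    for i, req_id in enumerate(req_ids):
        start = query_start_loc[i]
        spans[req_id] = (start, start + num_scheduled_tokens[req_id])
    return spans


def covers_all_spans(num_rows: int, spans: dict[str, tuple[int, int]]) -> bool:
    """True if a buffer with num_rows rows contains every span."""
    return all(0 <= start <= end <= num_rows for start, end in spans.values())


spans = scheduled_spans(["a", "b"], [0, 3, 5], {"a": 3, "b": 2})
print(spans)                       # {'a': (0, 3), 'b': (3, 5)}
print(covers_all_spans(5, spans))  # True
print(covers_all_spans(4, spans))  # False
```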

get_merged_multimodal_states

get_merged_multimodal_states(
    query_start_loc: Tensor,
    input_batch: InputBatch,
    multimodal_outputs: dict,
    num_scheduled_tokens: dict[str, int],
)

Get the merged multimodal states if hidden state prefix caching is enabled.
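The docstring does not define what "merged" means. One plausible reading, assumed here and not confirmed by the source, is that rows recovered from the prefix cache for a request's cached prefix are followed by the rows computed in the current step, independently for each multimodal key. merge_request_outputs and the keys audio_codes and text_embeds are hypothetical; note that the two keys have different feature sizes.

```python
from __future__ import annotations


def merge_request_outputs(
    cached_prefix: dict[str, list[list[float]]],
    fresh: dict[str, list[list[float]]],
) -> dict[str, list[list[float]]]:
    """ASSUMPTION: merged = cached prefix rows followed by this step's rows, per key."""
    return {key: cached_prefix.get(key, []) + rows for key, rows in fresh.items()}


merged = merge_request_outputs(
    cached_prefix={"audio_codes": [[1.0], [2.0]]},
    fresh={"audio_codes": [[3.0]], "text_embeds": [[0.5, 0.5]]},
)
print(merged)  # {'audio_codes': [[1.0], [2.0], [3.0]], 'text_embeds': [[0.5, 0.5]]}
```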

has_prefix_cached_new_req_ids

has_prefix_cached_new_req_ids() -> bool

Return True when this step contains a newly scheduled prefix hit.

maybe_init_missing_mm_cache_keys

maybe_init_missing_mm_cache_keys(
    multimodal_outputs: dict, seq_len: int
)

Given the multimodal outputs from executing the model, dynamically determine which of them are tensors whose size depends on the sequence length and should therefore be cached, and initialize the corresponding cache tensors.

NOTE: This avoids having to specify cache keys explicitly for every model/stage, and it matches how multimodal outputs are currently sliced along their first dimension.

This is usually triggered by the first forward pass, so the cache keys are typically determined during warmup.
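The key-discovery step can be sketched as follows, assuming that an output is cacheable when its leading dimension equals seq_len, consistent with slicing multimodal outputs along the first dimension. FakeTensor and infer_cache_keys are hypothetical; with real tensors the same getattr(value, "shape") check works because torch.Size is a tuple subclass.

```python
from __future__ import annotations

import math
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Stand-in for a tensor; only its shape matters in this sketch."""
    shape: tuple[int, ...]


def infer_cache_keys(multimodal_outputs: dict[str, object], seq_len: int) -> dict[str, int]:
    """Return {key: feature_size} for outputs whose leading dimension equals seq_len.

    feature_size is the product of the trailing dimensions (1 for a 1-D output),
    i.e. the width of one cached row per token.
    """
    cache_keys = {}
    for key, value in multimodal_outputs.items():
        shape = getattr(value, "shape", None)
        if shape is None or len(shape) == 0:
            continue  # not tensor-like, or a scalar: nothing to slice per token
        if shape[0] == seq_len:
            cache_keys[key] = math.prod(shape[1:])
    return cache_keys


outputs = {
    "codes": FakeTensor((6, 8)),          # per-token output: cached
    "speaker_emb": FakeTensor((1, 192)),  # per-request output: skipped
    "voice": "A",                         # not a tensor: skipped
}
print(infer_cache_keys(outputs, seq_len=6))  # {'codes': 8}
```

This heuristic would also pick up an output whose leading dimension only coincidentally equals seq_len (for example when seq_len is 1), so a real implementation may need extra filtering.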

reset_prefix_cached_new_req_ids

reset_prefix_cached_new_req_ids()

Clears the cache hit IDs to prepare for a new engine step.
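The three request-id methods form a per-step lifecycle: reset at the start of an engine step, add each newly scheduled prefix hit, then query. The sketch assumes a plain set as the backing store; PrefixHitTracker is a hypothetical class whose method names mirror the ones documented here.

```python
from __future__ import annotations


class PrefixHitTracker:
    """Hypothetical stand-in for the per-step prefix-hit bookkeeping."""

    def __init__(self) -> None:
        self._new_hit_req_ids: set[str] = set()

    def add_prefix_cached_new_req_id(self, req_id: str) -> None:
        self._new_hit_req_ids.add(req_id)

    def has_prefix_cached_new_req_ids(self) -> bool:
        return bool(self._new_hit_req_ids)

    def reset_prefix_cached_new_req_ids(self) -> None:
        self._new_hit_req_ids.clear()


tracker = PrefixHitTracker()
for step_hits in (["req-0"], [], ["req-1", "req-2"]):
    tracker.reset_prefix_cached_new_req_ids()  # start of an engine step
    for req_id in step_hits:
        tracker.add_prefix_cached_new_req_id(req_id)
    print(tracker.has_prefix_cached_new_req_ids())  # prints True, then False, then True
```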

stage_deferred_mm_outputs

stage_deferred_mm_outputs(
    query_start_loc: Tensor,
    input_batch: InputBatch,
    multimodal_outputs: dict[str, Tensor] | None,
    num_scheduled_tokens: dict[str, int],
    deferred_mm_cache_keys: set[str],
) -> None

Keep GPU multimodal chunks until a request finishes.

The normal prefix-cache write path copies every cached multimodal tensor to CPU on every step. For model outputs that are only needed by future full-block prefix hits, we can keep detached GPU chunks and materialize the CPU cache once when the request completes.
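The deferred path trades repeated per-step CPU copies for a single write per finished request. A minimal sketch of the stage, commit and discard flow follows. DeferredMMChunks is hypothetical, Python lists stand in for detached GPU chunks, slots_for is assumed to hold each request's already-translated slot ids, and the cache is modelled as cache[key][slot] = row.

```python
from __future__ import annotations


class DeferredMMChunks:
    """Hypothetical buffer of per-request multimodal rows awaiting a cache commit."""

    def __init__(self) -> None:
        # req_id -> key -> rows (one row per token, in token order)
        self._pending: dict[str, dict[str, list[list[float]]]] = {}

    def stage(self, req_id: str, key: str, rows: list[list[float]]) -> None:
        """Append this step's rows for req_id; nothing is written to the cache yet."""
        self._pending.setdefault(req_id, {}).setdefault(key, []).extend(rows)

    def commit(
        self,
        finished_req_ids: list[str],
        slots_for: dict[str, list[int]],
        cache: dict[str, dict[int, list[float]]],
    ) -> None:
        """Scatter each finished request's buffered rows into cache[key][slot]."""
        for req_id in finished_req_ids:
            for key, rows in self._pending.pop(req_id, {}).items():
                key_cache = cache.setdefault(key, {})
                for slot, row in zip(slots_for[req_id], rows):
                    key_cache[slot] = row

    def discard(self, req_id: str) -> None:
        """Drop buffered rows for a request that leaves without a commit."""
        self._pending.pop(req_id, None)


buf = DeferredMMChunks()
buf.stage("r1", "codes", [[1.0], [2.0]])  # step 1
buf.stage("r1", "codes", [[3.0]])         # step 2
buf.stage("r2", "codes", [[9.0]])
cache: dict[str, dict[int, list[float]]] = {}
buf.commit(["r1"], slots_for={"r1": [28, 29, 30]}, cache=cache)
buf.discard("r2")
print(cache)  # {'codes': {28: [1.0], 29: [2.0], 30: [3.0]}}
```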

update_omni_tensor_prefix_cache

update_omni_tensor_prefix_cache(
    hidden_states: Tensor | None,
    multimodal_outputs: dict[str, Tensor] | None,
    num_tokens_unpadded: int,
    slot_mapping: Tensor,
    num_tokens_padded: int | None = None,
    skip_mm_cache_keys: set[str] | None = None,
    hidden_states_cpu: Tensor | None = None,
)

Updates the hidden-state cache and the multimodal output caches with the provided hidden states and multimodal outputs.

Parameters:

hidden_states (Tensor | None, required): Hidden states tensor to cache, if any.

multimodal_outputs (dict[str, Tensor] | None, required): Multimodal output dict whose tensors may be cached.

num_tokens_unpadded (int, required): Number of tokens, excluding padding.

slot_mapping (Tensor, required): Slot mapping for the input sequence.

num_tokens_padded (int | None, default None): Total number of tokens, including padding.

skip_mm_cache_keys (set[str] | None, default None): Multimodal keys whose CPU cache writes are deferred.

hidden_states_cpu (Tensor | None, default None): Optional pre-staged CPU view of hidden_states. When provided, it must be contiguous, live on CPU, match the feature shape of hidden_states, and cover num_tokens_unpadded.
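The interaction between num_tokens_unpadded and skip_mm_cache_keys can be sketched as follows, assuming that only the first num_tokens_unpadded slots carry real tokens and that skipped keys are written later by the deferred path. update_cache_sketch is hypothetical, treats hidden states as just another key for brevity, and models each per-key cache as a dict from slot to row.

```python
from __future__ import annotations


def update_cache_sketch(
    cache: dict[str, dict[int, list[float]]],
    outputs: dict[str, list[list[float]]],
    slot_mapping: list[int],
    num_tokens_unpadded: int,
    skip_keys: set[str] | None = None,
) -> None:
    """Write the first num_tokens_unpadded rows of each non-skipped output to its slots."""
    skip_keys = skip_keys or set()
    real_slots = slot_mapping[:num_tokens_unpadded]  # drop padded positions
    for key, rows in outputs.items():
        if key in skip_keys:
            continue  # left for a deferred commit
        key_cache = cache.setdefault(key, {})
        for slot, row in zip(real_slots, rows[:num_tokens_unpadded]):
            key_cache[slot] = row


cache: dict[str, dict[int, list[float]]] = {}
update_cache_sketch(
    cache,
    outputs={"hidden": [[0.1], [0.2], [0.0]], "codes": [[5.0], [6.0], [0.0]]},
    slot_mapping=[8, 9, -1],  # the last entry stands for a padded token
    num_tokens_unpadded=2,
    skip_keys={"codes"},
)
print(cache)  # {'hidden': {8: [0.1], 9: [0.2]}}
```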