
vllm_omni.diffusion.cache.prompt_embed_cache

Prompt-embedding cache for diffusion pipelines.

Text encoders are the single most expensive per-request preprocessing step in most diffusion pipelines, and their inputs are frequently identical across requests (e.g. in GRPO-style rollouts the same prompt is submitted many times with different seeds to sample different images). Re-running the text encoder for each of those requests wastes compute and GPU memory.

This module provides a small LRU cache plus a transparent wrapper for a pipeline's encode_prompt method. Because almost every diffusion pipeline in vllm_omni/diffusion/models routes all text-encoder invocations through self.encode_prompt(...), wrapping that single method is sufficient to cache results model-wide without per-pipeline edits.

Design points
  • The wrapper is installed by DiffusionModelRunner after the pipeline has loaded, so each runner process owns its own cache.
  • Cache keys are derived from the bound encode_prompt arguments. Only inputs we can safely hash (str / int / float / bool / None / bytes / torch.device / torch.dtype / numpy scalars / nested lists/tuples/dicts of those) participate in the key. If any argument is a tensor, PIL image, or other non-trivial object, we bypass the cache for that call to guarantee correctness.
  • If a caller passes precomputed *_embeds into encode_prompt, the wrapper also bypasses the cache, because the call is already short-circuited.
  • Cache values are detached tensors. Downstream pipeline code typically does non-inplace ops (.repeat, .view, slicing); we do not clone on hit since outputs are treated as read-only.
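
The key-derivation rule in the design points can be illustrated with a minimal sketch. It is not the module's actual code: freeze, make_key and _UNHASHABLE are hypothetical names, and only standard-library scalar types are handled here (the real module also accepts torch.device, torch.dtype and numpy scalars). Returning None from make_key stands for "bypass the cache for this call".

```python
import inspect
from typing import Any

# Hypothetical sentinel: marks a value that cannot be keyed safely.
_UNHASHABLE = object()

_SCALARS = (str, int, float, bool, bytes, type(None))


def freeze(value: Any) -> Any:
    """Turn a value into a hashable key fragment, or return _UNHASHABLE."""
    if isinstance(value, _SCALARS):
        # Tag with the type name so that 1, 1.0 and True produce distinct keys.
        return (type(value).__name__, value)
    if isinstance(value, (list, tuple)):
        items = tuple(freeze(v) for v in value)
        if any(item is _UNHASHABLE for item in items):
            return _UNHASHABLE
        return (type(value).__name__, items)
    if isinstance(value, dict):
        items = []
        for k, v in value.items():
            fk, fv = freeze(k), freeze(v)
            if fk is _UNHASHABLE or fv is _UNHASHABLE:
                return _UNHASHABLE
            items.append((fk, fv))
        # Sort so that dict insertion order does not change the key.
        return ("dict", tuple(sorted(items, key=repr)))
    # Tensors, images and any other object: refuse to build a key.
    return _UNHASHABLE


def make_key(fn, args: tuple, kwargs: dict) -> Any:
    """Return a cache key for the call, or None when the cache must be bypassed."""
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()  # positional and keyword spellings yield the same key
    for name, value in bound.arguments.items():
        # Precomputed embeddings passed as named parameters: bypass.
        if name.endswith("_embeds") and value is not None:
            return None
    key = freeze(dict(bound.arguments))
    return None if key is _UNHASHABLE else key
```

Binding the arguments to the signature before freezing them means encode_prompt("a cat") and encode_prompt(prompt="a cat") map to the same entry, which is the property a GRPO-style rollout relies on.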

logger module-attribute

logger = init_logger(__name__)

PromptEmbedCache

Thread-safe LRU cache for encode_prompt outputs.

The cache stores whatever encode_prompt returns (tensor, tuple of tensors, None, etc.). Lookup / insertion is O(1) amortized. Eviction is least-recently-used.

bypassed instance-attribute

bypassed = 0

enabled instance-attribute

enabled = enabled

hits instance-attribute

hits = 0

max_size instance-attribute

max_size = max_size

misses instance-attribute

misses = 0

clear

clear() -> None

get

get(key: Any) -> Any

Return the cached value or _CACHE_MISS if absent.

Returning a sentinel (rather than None) lets callers cache legitimate None results from the wrapped function.

put

put(key: Any, value: Any) -> None

stats

stats() -> dict[str, int]

install_prompt_embed_cache

install_prompt_embed_cache(
    pipeline: Any,
    *,
    max_size: int = 32,
    enabled: bool = True,
    model_tag: str | None = None,
) -> PromptEmbedCache | None

Wrap pipeline.encode_prompt so results are cached, keyed on the call's bound arguments.

Idempotent: calling twice on the same pipeline is a no-op and returns the existing cache. Returns None if the pipeline has no encode_prompt method.

resolve_prompt_embed_cache_config

resolve_prompt_embed_cache_config(
    enable: bool | None = None, max_size: int | None = None
) -> tuple[bool, int]

Combine explicit args with env-var overrides.

Environment variables (useful for quick enablement in GRPO jobs without touching config files):

  • OMNI_DIFFUSION_PROMPT_EMBED_CACHE (1 / 0 / true / false)
  • OMNI_DIFFUSION_PROMPT_EMBED_CACHE_SIZE (positive int)
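
One plausible way to combine explicit arguments with these variables is sketched below. Several details are assumptions rather than documented behavior: that a valid environment value takes precedence over the explicit argument, that malformed values are ignored, and that the fallbacks are disabled and 32. The function name resolve_sketch and its env parameter (which makes the example testable without touching the real environment) are hypothetical.

```python
import os
from typing import Mapping, Optional, Tuple

_TRUE = {"1", "true"}
_FALSE = {"0", "false"}


def resolve_sketch(
    enable: Optional[bool] = None,
    max_size: Optional[int] = None,
    env: Optional[Mapping[str, str]] = None,
    default_enable: bool = False,   # assumed fallback
    default_max_size: int = 32,     # assumed fallback
) -> Tuple[bool, int]:
    env = os.environ if env is None else env
    enabled = default_enable if enable is None else enable
    size = default_max_size if max_size is None else max_size

    # Assumption: a recognized env value overrides the explicit argument.
    raw = env.get("OMNI_DIFFUSION_PROMPT_EMBED_CACHE", "").strip().lower()
    if raw in _TRUE:
        enabled = True
    elif raw in _FALSE:
        enabled = False

    # Assumption: anything that is not a positive integer is ignored.
    raw_size = env.get("OMNI_DIFFUSION_PROMPT_EMBED_CACHE_SIZE", "").strip()
    if raw_size.isdecimal() and int(raw_size) > 0:
        size = int(raw_size)

    return enabled, size
```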

uninstall_prompt_embed_cache

uninstall_prompt_embed_cache(pipeline: Any) -> None

Restore the original encode_prompt on pipeline if wrapped.
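
The install / uninstall pair can be pictured with the following simplified sketch. It uses a plain dict instead of an LRU, keys only on the prompt string and keyword arguments, and stores its state in two attributes (_sketch_prompt_cache and _sketch_orig_encode_prompt) that are invented for this example; the real module's bookkeeping may differ.

```python
import functools
from typing import Any, Optional

_CACHE_ATTR = "_sketch_prompt_cache"        # hypothetical marker attribute
_ORIG_ATTR = "_sketch_orig_encode_prompt"   # hypothetical: holds the unwrapped method


def install(pipeline: Any) -> Optional[dict]:
    """Wrap pipeline.encode_prompt with a cache; idempotent."""
    if not callable(getattr(pipeline, "encode_prompt", None)):
        return None  # nothing to wrap
    existing = getattr(pipeline, _CACHE_ATTR, None)
    if existing is not None:
        return existing  # already installed: return the same cache

    original = pipeline.encode_prompt
    cache: dict = {}

    @functools.wraps(original)
    def cached_encode_prompt(prompt, **kwargs):
        key = (prompt, tuple(sorted(kwargs.items())))
        try:
            hash(key)
        except TypeError:
            # Unhashable argument (e.g. a tensor): bypass the cache.
            return original(prompt, **kwargs)
        if key not in cache:
            cache[key] = original(prompt, **kwargs)
        return cache[key]

    setattr(pipeline, _ORIG_ATTR, original)
    setattr(pipeline, _CACHE_ATTR, cache)
    pipeline.encode_prompt = cached_encode_prompt
    return cache


def uninstall(pipeline: Any) -> None:
    """Restore the original encode_prompt if it was wrapped."""
    original = getattr(pipeline, _ORIG_ATTR, None)
    if original is None:
        return  # not wrapped: nothing to do
    pipeline.encode_prompt = original
    delattr(pipeline, _ORIG_ATTR)
    delattr(pipeline, _CACHE_ATTR)
```

Keeping a reference to the original bound method on the pipeline itself is what makes both operations safe to call repeatedly: a second install finds the marker and returns early, and a second uninstall finds nothing to restore.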