
vllm_omni.diffusion.cache

Cache module for diffusion model inference acceleration.

This module provides a unified cache backend system for different caching strategies:

- TeaCache: Timestep Embedding Aware Cache for adaptive transformer caching
- MagCache: Magnitude-based Cache for adaptive transformer caching
- cache-dit: DBCache, SCM, and TaylorSeer caching strategies

Cache backends are instantiated directly via their constructors and configured via OmniDiffusionConfig.

Modules:

- base: Base cache backend interface for diffusion models.
- cache_dit_backend: cache-dit integration backend for vllm-omni.
- magcache
- prompt_embed_cache: Prompt-embedding cache for diffusion pipelines.
- selector
- teacache: TeaCache: Timestep Embedding Aware Cache for diffusion model acceleration.

CacheBackend

Bases: ABC

Abstract base class for cache backends.

All cache backend implementations (CacheDiTBackend, TeaCacheBackend, etc.) inherit from this base class and implement the enable() and refresh() methods to manage cache lifecycle.

Cache backends apply caching strategies to transformer models to accelerate inference. Different backends use different underlying mechanisms (e.g., cache-dit library for CacheDiTBackend, hooks for TeaCacheBackend), but all share the same unified interface.
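The shared interface can be illustrated with a small stand-in. The sketch below does not import vllm_omni: `CacheBackendSketch` only mirrors the method names documented on this page, while `LifecycleOnlyBackend` and the `SimpleNamespace` pipeline are hypothetical names used for illustration.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import Any


class CacheBackendSketch(ABC):
    """Stand-in mirroring the documented interface (not the real class)."""

    def __init__(self, config: Any) -> None:
        self.config = config
        self.enabled = False

    @abstractmethod
    def enable(self, pipeline: Any) -> None: ...

    @abstractmethod
    def refresh(
        self, pipeline: Any, num_inference_steps: int, verbose: bool = True
    ) -> None: ...

    def is_enabled(self) -> bool:
        return self.enabled


class LifecycleOnlyBackend(CacheBackendSketch):
    """Hypothetical backend that records lifecycle calls and caches nothing."""

    def enable(self, pipeline: Any) -> None:
        # Called once at pipeline initialization.
        self.transformer = pipeline.transformer
        self.model_type = pipeline.__class__.__name__
        self.enabled = True

    def refresh(
        self, pipeline: Any, num_inference_steps: int, verbose: bool = True
    ) -> None:
        # Called before each generation: drop cached values, reset counters.
        self.cached = {}
        self.num_steps = num_inference_steps


pipeline = SimpleNamespace(transformer=object())  # placeholder pipeline
backend = LifecycleOnlyBackend(config={"rel_l1_thresh": 0.2})
backend.enable(pipeline)
backend.refresh(pipeline, num_inference_steps=50, verbose=False)
```

A real backend would attach hooks or call into cache-dit inside `enable()`; the point here is only the call order: `enable()` once, then `refresh()` before every generation.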

Attributes:

- config: DiffusionCacheConfig instance containing cache-specific configuration parameters.
- enabled: Boolean flag indicating whether the cache is enabled (set to True after enable() is called).

config instance-attribute

config = config

enabled instance-attribute

enabled = False

enable abstractmethod

enable(pipeline: Any) -> None

Enable cache on the pipeline.

This method applies the caching strategy to the transformer(s) in the pipeline. The specific implementation depends on the backend (e.g., hooks for TeaCacheBackend, cache-dit library for CacheDiTBackend). Called once during pipeline initialization.

Parameters:

- pipeline (Any, required): Diffusion pipeline instance. The backend can extract:
  - transformer: via pipeline.transformer
  - model_type: via pipeline.__class__.__name__

is_enabled

is_enabled() -> bool

Check if cache is enabled on this backend.

Returns:

- bool: True if the cache is enabled, False otherwise.

refresh abstractmethod

refresh(
    pipeline: Any,
    num_inference_steps: int,
    verbose: bool = True,
) -> None

Refresh cache state for new generation.

This method should clear any cached values and reset counters/accumulators. Called at the start of each generation to ensure clean state.

Parameters:

- pipeline (Any, required): Diffusion pipeline instance. The backend can extract:
  - transformer: via pipeline.transformer
- num_inference_steps (int, required): Number of inference steps for the current generation. May be used for cache context updates.
- verbose (bool, default: True): Whether to log refresh operations.

CacheContext dataclass

Context object containing all model-specific information for caching.

This allows the TeaCacheHook to remain completely generic - all model-specific logic is encapsulated in the extractor that returns this context.

Attributes:

- modulated_input (Tensor): Tensor used for the cache decision (similarity comparison). Must be a torch.Tensor extracted from the first transformer block, typically after applying normalization and modulation.
- hidden_states (Tensor): Current hidden states (will be modified by caching). Must be a torch.Tensor representing the main image/latent states after preprocessing but before the transformer blocks.
- encoder_hidden_states (Tensor | None): Optional encoder states for dual-stream models. Set to None for single-stream models (e.g., Flux). For dual-stream models (e.g., Qwen), contains the text encoder outputs.
- temb (Tensor): Timestep embedding tensor. Must be a torch.Tensor containing the timestep conditioning.
- run_transformer_blocks (Callable[[], tuple[Tensor, ...]]): Callable that executes the model-specific transformer blocks. Signature: () -> tuple[torch.Tensor, ...]

  Returns a tuple containing:
  - [0]: processed hidden_states (required)
  - [1]: processed encoder_hidden_states (optional, only for dual-stream)

  Example for single-stream:

      def run_blocks():
          h = hidden_states
          for block in module.transformer_blocks:
              h = block(h, temb=temb)
          return (h,)

  Example for dual-stream:

      def run_blocks():
          h, e = hidden_states, encoder_hidden_states
          for block in module.transformer_blocks:
              e, h = block(h, e, temb=temb)
          return (h, e)

- postprocess (Callable[[Tensor], Any]): Callable that performs model-specific output postprocessing. Signature: (torch.Tensor) -> Union[torch.Tensor, Transformer2DModelOutput, tuple]

  Takes the processed hidden_states and applies the final transformations (normalization, projection) to produce the model output.

  Example:

      def postprocess(h):
          h = module.norm_out(h, temb)
          output = module.proj_out(h)
          return Transformer2DModelOutput(sample=output)

- extra_states (dict[str, Any] | None): Optional dict for additional model-specific state. Use this for models that need to pass additional context beyond the standard fields.
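To show why these fields let a hook stay model-agnostic, here is a minimal sketch of one denoising step driven only by the context's callables. It is an illustration under assumptions, not the TeaCacheHook implementation: `ContextSketch` and `generic_step` are hypothetical, plain Python lists stand in for tensors, and storing the residual as "output minus input" is assumed.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple


@dataclass
class ContextSketch:
    """Hypothetical, reduced stand-in for CacheContext (lists instead of tensors)."""

    hidden_states: List[float]
    run_transformer_blocks: Callable[[], tuple]
    postprocess: Callable[[List[float]], Any]


def generic_step(
    ctx: ContextSketch, cached_residual: Optional[List[float]], reuse: bool
) -> Tuple[Any, List[float]]:
    """Run or skip the transformer blocks without knowing the model."""
    if reuse and cached_residual is not None:
        # Skip the blocks: approximate their output with the cached residual.
        hidden = [h + r for h, r in zip(ctx.hidden_states, cached_residual)]
        residual = cached_residual
    else:
        # Full computation; element [0] is the processed hidden_states.
        hidden = list(ctx.run_transformer_blocks()[0])
        residual = [o - i for o, i in zip(hidden, ctx.hidden_states)]
    return ctx.postprocess(hidden), residual


def make_ctx(hidden: List[float]) -> ContextSketch:
    def run_blocks() -> tuple:
        return ([2.0 * h for h in hidden],)  # toy "transformer blocks"

    def postprocess(h: List[float]) -> dict:
        return {"sample": h}

    return ContextSketch(hidden, run_blocks, postprocess)


out1, residual = generic_step(make_ctx([1.0, 2.0]), None, reuse=False)
out2, _ = generic_step(make_ctx([1.5, 2.5]), residual, reuse=True)
```

All model-specific work lives in `run_blocks` and `postprocess`, which is the role the extractor plays when it builds a CacheContext.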

encoder_hidden_states instance-attribute

encoder_hidden_states: Tensor | None

extra_states class-attribute instance-attribute

extra_states: dict[str, Any] | None = None

hidden_states instance-attribute

hidden_states: Tensor

modulated_input instance-attribute

modulated_input: Tensor

postprocess instance-attribute

postprocess: Callable[[Tensor], Any]

run_transformer_blocks instance-attribute

run_transformer_blocks: Callable[[], tuple[Tensor, ...]]

temb instance-attribute

temb: Tensor

validate

validate() -> None

Validate that the CacheContext contains valid data.

Raises:

- TypeError: If fields have the wrong types.
- ValueError: If tensors have invalid properties.
- RuntimeError: If callables fail basic invocation tests.

This method should be called after creating a CacheContext to catch common developer errors early with clear error messages.

PromptEmbedCache

Thread-safe LRU cache for encode_prompt outputs.

The cache stores whatever encode_prompt returns (tensor, tuple of tensors, None, etc.). Lookup / insertion is O(1) amortized. Eviction is least-recently-used.

bypassed instance-attribute

bypassed = 0

enabled instance-attribute

enabled = enabled

hits instance-attribute

hits = 0

max_size instance-attribute

max_size = max_size

misses instance-attribute

misses = 0

clear

clear() -> None

get

get(key: Any) -> Any

Return the cached value or _CACHE_MISS if absent.

Returning a sentinel (rather than None) lets callers cache legitimate None results from the wrapped function.

put

put(key: Any, value: Any) -> None

stats

stats() -> dict[str, int]

TeaCacheBackend

Bases: CacheBackend

TeaCache implementation using hooks.

TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up diffusion inference by reusing transformer block computations when consecutive timestep embeddings are similar.

The backend applies TeaCache hooks to the transformer which intercept the forward pass and implement the caching logic transparently.

Example

    from vllm_omni.diffusion.data import DiffusionCacheConfig

    backend = TeaCacheBackend(DiffusionCacheConfig(rel_l1_thresh=0.2))
    backend.enable(pipeline)

    # Generate with cache enabled
    backend.refresh(pipeline, num_inference_steps=50)  # Refresh before each generation

    # Access config attributes: backend.config.rel_l1_thresh

enable

enable(pipeline: Any) -> None

Enable TeaCache on transformer using hooks.

This creates a TeaCacheConfig from the backend's DiffusionCacheConfig and applies the TeaCache hook to the transformer.

Parameters:

- pipeline (Any, required): Diffusion pipeline instance. Extracts transformer and transformer_type:
  - transformer: pipeline.transformer
  - transformer_type: pipeline.transformer.__class__.__name__

refresh

refresh(
    pipeline: Any,
    num_inference_steps: int,
    verbose: bool = True,
) -> None

Refresh TeaCache state for new generation.

Clears all cached residuals and resets counters/accumulators. Should be called before each generation to ensure clean state.

Parameters:

- pipeline (Any, required): Diffusion pipeline instance. Extracts transformer via pipeline.transformer.
- num_inference_steps (int, required): Number of inference steps for the current generation. Currently not used by TeaCache but accepted for interface consistency.
- verbose (bool, default: True): Whether to log refresh operations.

TeaCacheConfig dataclass

Configuration for TeaCache applied to transformer models.

TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up diffusion model inference by reusing transformer block computations when consecutive timestep embeddings are similar.

Parameters:

- rel_l1_thresh (float, default: 0.2): Threshold for the accumulated relative L1 distance. While the accumulated distance stays below the threshold, the cached residual is reused. Values in [0.1, 0.3] work best:
  - 0.2: ~1.5x speedup with minimal quality loss
  - 0.4: ~1.8x speedup with slight quality loss
  - 0.6: ~2.0x speedup with noticeable quality loss
- coefficients (list[float] | None, default: None): Polynomial coefficients for rescaling the L1 distance. If None, model-specific defaults are chosen based on transformer_type.
- transformer_type (str, default: 'QwenImageTransformer2DModel'): Transformer class name (e.g., "QwenImageTransformer2DModel"). Auto-detected by the backend from pipeline.transformer.__class__.__name__.
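How rel_l1_thresh and coefficients interact can be sketched in plain Python. This follows the decision rule of the published TeaCache method as an assumption; it is not taken from the vllm-omni source. `ThresholdGate` is a hypothetical name, lists of floats stand in for tensors, and the identity polynomial `(1.0, 0.0)` replaces the model-specific coefficients.

```python
from typing import List, Optional, Sequence


def rel_l1(cur: Sequence[float], prev: Sequence[float]) -> float:
    """Relative L1 distance between consecutive inputs (prev assumed non-zero)."""
    num = sum(abs(c - p) for c, p in zip(cur, prev))
    den = sum(abs(p) for p in prev)
    return num / den


def polyval(coeffs: Sequence[float], x: float) -> float:
    """Evaluate a polynomial with highest-degree coefficient first (Horner)."""
    result = 0.0
    for c in coeffs:
        result = result * x + c
    return result


class ThresholdGate:
    def __init__(
        self, rel_l1_thresh: float = 0.2, coefficients: Sequence[float] = (1.0, 0.0)
    ) -> None:
        self.rel_l1_thresh = rel_l1_thresh
        self.coefficients = coefficients
        self.prev: Optional[List[float]] = None
        self.accumulated = 0.0

    def should_compute(self, modulated_input: Sequence[float]) -> bool:
        if self.prev is None:
            compute = True  # first step: nothing cached yet
        else:
            distance = rel_l1(modulated_input, self.prev)
            self.accumulated += polyval(self.coefficients, distance)
            compute = self.accumulated >= self.rel_l1_thresh
        if compute:
            self.accumulated = 0.0  # assumed: reset after a full computation
        self.prev = list(modulated_input)
        return compute


gate = ThresholdGate(rel_l1_thresh=0.2)
inputs = [[1.0, 1.0], [1.2, 1.0], [1.4, 1.0], [1.6, 1.0]]
decisions = [gate.should_compute(x) for x in inputs]
```

With these toy inputs, the second and third steps reuse the cached residual and the fourth recomputes because the accumulated distance has crossed 0.2. A larger threshold lets more consecutive steps be skipped, which is the speed/quality trade-off listed above.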

coefficients class-attribute instance-attribute

coefficients: list[float] | None = None

rel_l1_thresh class-attribute instance-attribute

rel_l1_thresh: float = 0.2

transformer_type class-attribute instance-attribute

transformer_type: str = 'QwenImageTransformer2DModel'

apply_teacache_hook

apply_teacache_hook(
    module: Module, config: TeaCacheConfig
) -> None

Apply TeaCache optimization to a transformer module.

This function registers a TeaCacheHook that completely intercepts the module's forward pass, implementing adaptive caching without any changes to the model code.

Parameters:

- module (Module, required): Transformer model to optimize (e.g., QwenImageTransformer2DModel).
- config (TeaCacheConfig, required): TeaCacheConfig specifying caching parameters.

Example

    config = TeaCacheConfig(
        rel_l1_thresh=0.2,
        transformer_type="QwenImageTransformer2DModel",
    )
    apply_teacache_hook(transformer, config)
    # The transformer bound to the pipeline now uses TeaCache automatically;
    # no code changes needed.

install_prompt_embed_cache

install_prompt_embed_cache(
    pipeline: Any,
    *,
    max_size: int = 32,
    enabled: bool = True,
    model_tag: str | None = None,
) -> PromptEmbedCache | None

Wrap pipeline.encode_prompt so results are cached by argument identity.

Idempotent: calling twice on the same pipeline is a no-op and returns the existing cache. Returns None if the pipeline has no encode_prompt method.
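The wrap-once pattern described here can be sketched as follows. Everything in the block is hypothetical: `install_cache_sketch`, the marker attribute, the unbounded dict, and the key built from hashable arguments are assumptions for illustration; the real function returns a PromptEmbedCache bounded by max_size, and its key construction is not documented on this page.

```python
from types import SimpleNamespace
from typing import Any, Optional

_MARKER = "_prompt_embed_cache_sketch"  # hypothetical attribute name


def install_cache_sketch(pipeline: Any) -> Optional[dict]:
    if not hasattr(pipeline, "encode_prompt"):
        return None  # nothing to wrap
    existing = getattr(pipeline, _MARKER, None)
    if existing is not None:
        return existing  # already installed: no-op, return the existing cache

    original = pipeline.encode_prompt
    cache: dict = {}

    def cached_encode_prompt(*args: Any, **kwargs: Any) -> Any:
        # Assumes all arguments are hashable.
        key = (args, tuple(sorted(kwargs.items())))
        if key not in cache:
            cache[key] = original(*args, **kwargs)
        return cache[key]

    pipeline.encode_prompt = cached_encode_prompt
    setattr(pipeline, _MARKER, cache)
    return cache


calls = []


def fake_encode(prompt: str) -> str:
    calls.append(prompt)
    return f"emb({prompt})"


pipeline = SimpleNamespace(encode_prompt=fake_encode)
cache = install_cache_sketch(pipeline)
same_cache = install_cache_sketch(pipeline)  # second call is a no-op
out1 = pipeline.encode_prompt("a cat")
out2 = pipeline.encode_prompt("a cat")       # served from the cache
```

Storing the cache on the pipeline is what makes the second call detectable; an uninstall step would need to keep a reference to `original` as well.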

resolve_prompt_embed_cache_config

resolve_prompt_embed_cache_config(
    enable: bool | None = None, max_size: int | None = None
) -> tuple[bool, int]

Combine explicit args with env-var overrides.

Environment variables (useful for quick enablement in GRPO jobs without touching config files):

- OMNI_DIFFUSION_PROMPT_EMBED_CACHE: 1 / 0 / true / false
- OMNI_DIFFUSION_PROMPT_EMBED_CACHE_SIZE: positive int
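One way such a resolution step could work is sketched below. The precedence (environment variables win over explicit arguments) is inferred from the word "overrides", and the fallback defaults (`False`, `32`) are assumptions; `resolve_sketch` is a hypothetical name, not the library function.

```python
import os
from typing import Optional, Tuple

_TRUE = {"1", "true"}
_FALSE = {"0", "false"}


def resolve_sketch(
    enable: Optional[bool] = None,
    max_size: Optional[int] = None,
    default_enable: bool = False,  # assumed default
    default_size: int = 32,        # assumed default
) -> Tuple[bool, int]:
    resolved_enable = default_enable if enable is None else enable
    resolved_size = default_size if max_size is None else max_size

    raw = os.environ.get("OMNI_DIFFUSION_PROMPT_EMBED_CACHE", "").strip().lower()
    if raw in _TRUE:
        resolved_enable = True
    elif raw in _FALSE:
        resolved_enable = False

    raw_size = os.environ.get("OMNI_DIFFUSION_PROMPT_EMBED_CACHE_SIZE", "").strip()
    try:
        size = int(raw_size)
    except ValueError:
        size = 0  # unset or not an integer: ignore
    if size > 0:
        resolved_size = size

    return resolved_enable, resolved_size


os.environ["OMNI_DIFFUSION_PROMPT_EMBED_CACHE"] = "true"
os.environ["OMNI_DIFFUSION_PROMPT_EMBED_CACHE_SIZE"] = "64"
resolved = resolve_sketch(enable=False, max_size=8)
```

Unrecognized or non-positive values are ignored in this sketch, so a malformed variable falls back to the explicit arguments instead of raising.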

uninstall_prompt_embed_cache

uninstall_prompt_embed_cache(pipeline: Any) -> None

Restore the original encode_prompt on pipeline if wrapped.