vllm_omni.diffusion.cache ¶
Cache module for diffusion model inference acceleration.
This module provides a unified cache backend system for different caching strategies:

- TeaCache: Timestep Embedding Aware Cache for adaptive transformer caching
- MagCache: Magnitude-based Cache for adaptive transformer caching
- cache-dit: DBCache, SCM, and TaylorSeer caching strategies
Cache backends are instantiated directly via their constructors and configured via OmniDiffusionConfig.
Modules:
| Name | Description |
|---|---|
| base | Base cache backend interface for diffusion models. |
| cache_dit_backend | cache-dit integration backend for vllm-omni. |
| magcache | |
| prompt_embed_cache | Prompt-embedding cache for diffusion pipelines. |
| selector | |
| teacache | TeaCache: Timestep Embedding Aware Cache for diffusion model acceleration. |
CacheBackend ¶
Bases: ABC
Abstract base class for cache backends.
All cache backend implementations (CacheDiTBackend, TeaCacheBackend, etc.) inherit from this base class and implement the enable() and refresh() methods to manage cache lifecycle.
Cache backends apply caching strategies to transformer models to accelerate inference. Different backends use different underlying mechanisms (e.g., cache-dit library for CacheDiTBackend, hooks for TeaCacheBackend), but all share the same unified interface.
Attributes:
| Name | Type | Description |
|---|---|---|
| config | DiffusionCacheConfig | Instance containing cache-specific configuration parameters. |
| enabled | bool | Flag indicating whether the cache is enabled (set to True after enable() is called). |
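As a rough illustration of the interface described above, the sketch below mirrors it with a stand-in base class so that it runs without vllm-omni installed. `SketchCacheBackend`, `CountingBackend`, and the dictionary config are hypothetical names chosen for this example, and the `refresh` signature is inferred from the documented parameters rather than copied from the library.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import Any


class SketchCacheBackend(ABC):
    """Stand-in for CacheBackend (assumption: mirrors the documented interface)."""

    def __init__(self, config: Any) -> None:
        self.config = config
        self.enabled = False

    @abstractmethod
    def enable(self, pipeline: Any) -> None: ...

    @abstractmethod
    def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None: ...

    def is_enabled(self) -> bool:
        return self.enabled


class CountingBackend(SketchCacheBackend):
    """Hypothetical backend that only tracks per-generation state."""

    def enable(self, pipeline: Any) -> None:
        # A real backend would hook or patch pipeline.transformer here.
        self.model_type = pipeline.__class__.__name__
        self.steps_seen = 0
        self.enabled = True

    def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        # Reset counters so each generation starts from a clean state.
        self.steps_seen = 0
        self.num_inference_steps = num_inference_steps


pipeline = SimpleNamespace(transformer=object())   # placeholder pipeline
backend = CountingBackend(config={"rel_l1_thresh": 0.2})
backend.enable(pipeline)                            # once, at pipeline initialization
backend.refresh(pipeline, num_inference_steps=50)   # before every generation
```

The split between a one-time `enable` and a per-generation `refresh` keeps the expensive setup out of the generation loop while still guaranteeing that no cached state leaks from one request into the next.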
enable abstractmethod ¶
enable(pipeline: Any) -> None
Enable cache on the pipeline.
This method applies the caching strategy to the transformer(s) in the pipeline. The specific implementation depends on the backend (e.g., hooks for TeaCacheBackend, cache-dit library for CacheDiTBackend). Called once during pipeline initialization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pipeline | Any | Diffusion pipeline instance. The backend can extract the transformer via `pipeline.transformer` and the model type via `pipeline.__class__.__name__`. | required |
is_enabled ¶
is_enabled() -> bool
Check if cache is enabled on this backend.
Returns:
| Type | Description |
|---|---|
| bool | True if cache is enabled, False otherwise. |
refresh abstractmethod ¶
Refresh cache state for new generation.
This method should clear any cached values and reset counters/accumulators. Called at the start of each generation to ensure clean state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pipeline | Any | Diffusion pipeline instance. The backend can extract the transformer via `pipeline.transformer`. | required |
| num_inference_steps | int | Number of inference steps for the current generation. May be used for cache context updates. | required |
| verbose | bool | Whether to log refresh operations. | True |
CacheContext dataclass ¶
Context object containing all model-specific information for caching.
This allows the TeaCacheHook to remain completely generic - all model-specific logic is encapsulated in the extractor that returns this context.
Attributes:
| Name | Type | Description |
|---|---|---|
| modulated_input | Tensor | Tensor used for the cache decision (similarity comparison). Must be a torch.Tensor extracted from the first transformer block, typically after applying normalization and modulation. |
| hidden_states | Tensor | Current hidden states (will be modified by caching). Must be a torch.Tensor representing the main image/latent states after preprocessing but before the transformer blocks. |
| encoder_hidden_states | `Tensor \| None` | Optional encoder states (for dual-stream models). Set to None for single-stream models (e.g., Flux). For dual-stream models (e.g., Qwen), contains text encoder outputs. |
| temb | Tensor | Timestep embedding tensor. Must be a torch.Tensor containing the timestep conditioning. |
| run_transformer_blocks | Callable[[], tuple[Tensor, ...]] | Callable that executes the model-specific transformer blocks. Signature: `() -> tuple[torch.Tensor, ...]`. Returns a tuple containing the processed hidden_states at index 0 (required) and the processed encoder_hidden_states at index 1 (optional, dual-stream only). |
| postprocess | Callable[[Tensor], Any] | Callable that does model-specific output postprocessing. Signature: `(torch.Tensor) -> Union[torch.Tensor, Transformer2DModelOutput, tuple]`. Takes the processed hidden_states and applies the final transformations (normalization, projection) to produce the model output. |
| extra_states | `dict[str, Any] \| None` | Optional dict for additional model-specific state. Use this for models that need to pass additional context beyond the standard fields. |

Example run_transformer_blocks for a single-stream model:

```python
def run_blocks():
    h = hidden_states
    for block in module.transformer_blocks:
        h = block(h, temb=temb)
    return (h,)
```

Example run_transformer_blocks for a dual-stream model:

```python
def run_blocks():
    h, e = hidden_states, encoder_hidden_states
    for block in module.transformer_blocks:
        e, h = block(h, e, temb=temb)
    return (h, e)
```

Example postprocess:

```python
def postprocess(h):
    h = module.norm_out(h, temb)
    output = module.proj_out(h)
    return Transformer2DModelOutput(sample=output)
```
run_transformer_blocks instance-attribute ¶
validate ¶
Validate that the CacheContext contains valid data.
Raises:
| Type | Description |
|---|---|
| TypeError | If fields have wrong types |
| ValueError | If tensors have invalid properties |
| RuntimeError | If callables fail basic invocation tests |
This method should be called after creating a CacheContext to catch common developer errors early with clear error messages.
PromptEmbedCache ¶
Thread-safe LRU cache for encode_prompt outputs.
The cache stores whatever encode_prompt returns (tensor, tuple of tensors, None, etc.). Lookup / insertion is O(1) amortized. Eviction is least-recently-used.
get ¶
Return the cached value or _CACHE_MISS if absent.
Returning a sentinel (rather than None) lets callers cache legitimate None results from the wrapped function.
TeaCacheBackend ¶
Bases: CacheBackend
TeaCache implementation using hooks.
TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up diffusion inference by reusing transformer block computations when consecutive timestep embeddings are similar.
The backend applies TeaCache hooks to the transformer which intercept the forward pass and implement the caching logic transparently.
Example
```python
from vllm_omni.diffusion.data import DiffusionCacheConfig

backend = TeaCacheBackend(DiffusionCacheConfig(rel_l1_thresh=0.2))
backend.enable(pipeline)

# Generate with cache enabled
backend.refresh(pipeline, num_inference_steps=50)  # Refresh before each generation

# Access config attributes: backend.config.rel_l1_thresh
```
enable ¶
enable(pipeline: Any) -> None
Enable TeaCache on transformer using hooks.
This creates a TeaCacheConfig from the backend's DiffusionCacheConfig and applies the TeaCache hook to the transformer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pipeline | Any | Diffusion pipeline instance. The transformer is read from `pipeline.transformer` and the transformer type from `pipeline.transformer.__class__.__name__`. | required |
refresh ¶
Refresh TeaCache state for new generation.
Clears all cached residuals and resets counters/accumulators. Should be called before each generation to ensure clean state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pipeline | Any | Diffusion pipeline instance. Extracts transformer via `pipeline.transformer`. | required |
| num_inference_steps | int | Number of inference steps for the current generation. Currently not used by TeaCache but accepted for interface consistency. | required |
| verbose | bool | Whether to log refresh operations. | True |
TeaCacheConfig dataclass ¶
Configuration for TeaCache applied to transformer models.
TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up diffusion model inference by reusing transformer block computations when consecutive timestep embeddings are similar.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rel_l1_thresh | float | Threshold for the accumulated relative L1 distance. While the accumulated distance stays below the threshold, the cached residual is reused. Values in [0.1, 0.3] work best. Reference points: 0.2 gives ~1.5x speedup with minimal quality loss; 0.4 gives ~1.8x speedup with slight quality loss; 0.6 gives ~2.0x speedup with noticeable quality loss. | 0.2 |
| coefficients | `list[float] \| None` | Polynomial coefficients for rescaling the L1 distance. If None, uses model-specific defaults based on transformer_type. | None |
| transformer_type | str | Transformer class name (e.g., "QwenImageTransformer2DModel"). Auto-detected from `pipeline.transformer.__class__.__name__` in the backend. | 'QwenImageTransformer2DModel' |
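The interaction between `rel_l1_thresh` and `coefficients` can be sketched in plain Python. This follows the general TeaCache scheme (accumulate a rescaled relative L1 distance and recompute once it reaches the threshold) and is not taken from the vllm-omni source. `SketchTeaCacheState` and the helper functions are hypothetical, inputs are plain lists instead of tensors, and the highest-degree-first coefficient order is an assumption.

```python
from typing import List, Optional, Sequence


def rel_l1(cur: Sequence[float], prev: Sequence[float]) -> float:
    """Mean absolute change divided by the mean magnitude of the previous input.

    Assumes prev is not all zeros.
    """
    num = sum(abs(c - p) for c, p in zip(cur, prev)) / len(prev)
    den = sum(abs(p) for p in prev) / len(prev)
    return num / den


def rescale(coefficients: Sequence[float], x: float) -> float:
    """Evaluate the polynomial with Horner's rule (assumed order: highest degree first)."""
    result = 0.0
    for c in coefficients:
        result = result * x + c
    return result


class SketchTeaCacheState:
    def __init__(self, rel_l1_thresh: float, coefficients: Sequence[float]) -> None:
        self.rel_l1_thresh = rel_l1_thresh
        self.coefficients = coefficients
        self.prev: Optional[List[float]] = None
        self.accumulated = 0.0

    def should_compute(self, modulated_input: Sequence[float]) -> bool:
        if self.prev is None:
            compute = True  # first step: nothing has been cached yet
        else:
            distance = rel_l1(modulated_input, self.prev)
            self.accumulated += rescale(self.coefficients, distance)
            compute = self.accumulated >= self.rel_l1_thresh
        if compute:
            self.accumulated = 0.0  # start accumulating again after a full pass
        self.prev = list(modulated_input)
        return compute


# Identity rescaling ([1.0, 0.0] evaluates to x) keeps the arithmetic easy to follow.
state = SketchTeaCacheState(rel_l1_thresh=0.25, coefficients=[1.0, 0.0])
decisions = [state.should_compute([v]) for v in (1.0, 1.1, 1.2, 1.6)]
```

In this run the two small changes accumulate to roughly 0.19 and are served from the cache, while the jump to 1.6 pushes the total past 0.25 and forces a full recomputation. A larger threshold tolerates more accumulated drift before recomputing, which is where the speed and quality trade-off in the table comes from.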
apply_teacache_hook ¶
apply_teacache_hook(
module: Module, config: TeaCacheConfig
) -> None
Apply TeaCache optimization to a transformer module.
This function registers a TeaCacheHook that completely intercepts the module's forward pass, implementing adaptive caching without any changes to the model code.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | Transformer model to optimize (e.g., QwenImageTransformer2DModel) | required |
| config | TeaCacheConfig | Configuration specifying caching parameters | required |
Example
```python
config = TeaCacheConfig(
    rel_l1_thresh=0.2,
    transformer_type="QwenImageTransformer2DModel",
)
apply_teacache_hook(transformer, config)
# Transformer bound to the pipeline now uses TeaCache automatically,
# no code changes needed!
```
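For intuition about what "intercepting the forward pass" can mean, the following toy sketch wraps a module's `forward` and reuses a cached residual when told to skip. It is only an assumption-laden illustration: `SketchTransformer`, `apply_sketch_hook`, and the scalar arithmetic are invented here, and the real TeaCacheHook may register itself through a different mechanism than reassigning `forward`.

```python
from typing import Callable, Optional


class SketchTransformer:
    """Hypothetical stand-in for a transformer module."""

    def __init__(self) -> None:
        self.block_calls = 0

    def forward(self, x: float) -> float:
        self.block_calls += 1
        return 2.0 * x + 1.0  # pretend this is the expensive block stack


def apply_sketch_hook(module: SketchTransformer, should_compute: Callable[[float], bool]) -> None:
    original_forward = module.forward
    cached_residual: Optional[float] = None

    def hooked_forward(x: float) -> float:
        nonlocal cached_residual
        if cached_residual is None or should_compute(x):
            out = original_forward(x)
            cached_residual = out - x  # remember what the blocks added
            return out
        return x + cached_residual     # skip the blocks and reuse the residual

    module.forward = hooked_forward    # the instance attribute shadows the method


model = SketchTransformer()
apply_sketch_hook(model, should_compute=lambda x: False)  # always reuse after the first call
first = model.forward(1.0)    # computed by the original forward
second = model.forward(1.5)   # served from the cached residual
```

Because the wrapper replaces the callable on the instance, code that calls `model.forward(...)` does not need to change, which is the property the documentation describes.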
install_prompt_embed_cache ¶
install_prompt_embed_cache(
pipeline: Any,
*,
max_size: int = 32,
enabled: bool = True,
model_tag: str | None = None,
) -> PromptEmbedCache | None
Wrap pipeline.encode_prompt so results are cached by argument identity.
Idempotent: calling twice on the same pipeline is a no-op and returns the existing cache. Returns None if the pipeline has no encode_prompt method.
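The wrapping and idempotency behaviour can be sketched as follows. This is a simplified illustration under stated assumptions, not the library's code: `install_sketch_cache`, `SketchPipeline`, and the `_sketch_prompt_cache` attribute are hypothetical, the cache key is built from hashable call arguments, eviction is FIFO rather than LRU, and the `enabled` and `model_tag` parameters are not modelled.

```python
from typing import Any, Optional

_MISS = object()  # sentinel so that cached None results still count as hits


def install_sketch_cache(pipeline: Any, *, max_size: int = 32) -> Optional[dict]:
    if not hasattr(pipeline, "encode_prompt"):
        return None                                   # nothing to wrap
    existing = getattr(pipeline, "_sketch_prompt_cache", None)
    if existing is not None:
        return existing                               # already installed: no-op
    cache: dict = {}
    original = pipeline.encode_prompt

    def cached_encode_prompt(*args: Any, **kwargs: Any) -> Any:
        key = (args, tuple(sorted(kwargs.items())))   # assumes hashable arguments
        value = cache.get(key, _MISS)
        if value is _MISS:
            value = original(*args, **kwargs)
            if len(cache) >= max_size:
                cache.pop(next(iter(cache)))          # drop the oldest insertion
            cache[key] = value
        return value

    pipeline.encode_prompt = cached_encode_prompt
    pipeline._sketch_prompt_cache = cache
    return cache


class SketchPipeline:
    def __init__(self) -> None:
        self.encode_calls = 0

    def encode_prompt(self, prompt: str) -> str:
        self.encode_calls += 1
        return prompt.upper()  # placeholder for real text-encoder output


pipe = SketchPipeline()
cache = install_sketch_cache(pipe)
same_cache = install_sketch_cache(pipe)   # second call returns the existing cache
outputs = [pipe.encode_prompt("a cat"), pipe.encode_prompt("a cat")]
```

Storing the cache on the pipeline itself is what makes the second call cheap to detect, and it avoids stacking one wrapper on top of another.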
resolve_prompt_embed_cache_config ¶
resolve_prompt_embed_cache_config(
enable: bool | None = None, max_size: int | None = None
) -> tuple[bool, int]
Combine explicit args with env-var overrides.
Environment variables (useful for quick enablement in GRPO jobs without touching config files):
- `OMNI_DIFFUSION_PROMPT_EMBED_CACHE` (`1`/`0`/`true`/`false`)
- `OMNI_DIFFUSION_PROMPT_EMBED_CACHE_SIZE` (positive int)
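One plausible way to combine explicit arguments with these variables is sketched below. Several details are assumptions rather than documented behaviour: that environment variables take precedence over explicit arguments, that the cache is disabled by default, that the default size is 32, and that invalid values are ignored. `resolve_sketch` and its `env` parameter are hypothetical and exist only to keep the example testable.

```python
import os
from typing import Mapping, Optional, Tuple

_TRUE = {"1", "true"}
_FALSE = {"0", "false"}


def resolve_sketch(
    enable: Optional[bool] = None,
    max_size: Optional[int] = None,
    env: Optional[Mapping[str, str]] = None,
) -> Tuple[bool, int]:
    env = os.environ if env is None else env
    enabled = False if enable is None else enable   # assumed default: disabled
    size = 32 if max_size is None else max_size     # assumed default: 32

    raw_flag = env.get("OMNI_DIFFUSION_PROMPT_EMBED_CACHE", "").strip().lower()
    if raw_flag in _TRUE:
        enabled = True                              # assumption: env var wins
    elif raw_flag in _FALSE:
        enabled = False

    raw_size = env.get("OMNI_DIFFUSION_PROMPT_EMBED_CACHE_SIZE", "").strip()
    if raw_size.isdigit() and int(raw_size) > 0:
        size = int(raw_size)                        # ignore non-positive or malformed values
    return enabled, size


defaults = resolve_sketch(env={})
overridden = resolve_sketch(
    enable=False,
    max_size=8,
    env={
        "OMNI_DIFFUSION_PROMPT_EMBED_CACHE": "true",
        "OMNI_DIFFUSION_PROMPT_EMBED_CACHE_SIZE": "64",
    },
)
```

Passing the environment as a mapping keeps the resolution logic pure, so it can be exercised in tests without mutating `os.environ`.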