vllm_omni.diffusion.cache.teacache ¶
TeaCache: Timestep Embedding Aware Cache for diffusion model acceleration.
TeaCache speeds up diffusion inference by reusing transformer block computations when consecutive timestep embeddings are similar.
This implementation uses a hooks-based approach that requires zero changes to model code. Model developers only need to add an extractor function to support new models.
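The caching idea described above can be illustrated with a small, self-contained sketch. It is not the library's implementation: plain Python lists stand in for tensors, `relative_l1` and `plan_steps` are hypothetical helper names, and the sketch omits the polynomial rescaling and the CFG handling documented further down.

```python
def relative_l1(current, previous):
    """Mean absolute change, relative to the mean magnitude of the previous input."""
    diff = sum(abs(c - p) for c, p in zip(current, previous)) / len(current)
    scale = sum(abs(p) for p in previous) / len(previous)
    return diff / scale


def plan_steps(modulated_inputs, rel_l1_thresh):
    """Decide, per step, whether to run the blocks or reuse the cached result."""
    decisions = []
    previous = None
    accumulated = 0.0
    for current in modulated_inputs:
        if previous is None:
            decisions.append("compute")  # nothing has been cached yet
        else:
            accumulated += relative_l1(current, previous)
            if accumulated < rel_l1_thresh:
                decisions.append("reuse")  # inputs barely changed
            else:
                decisions.append("compute")
                accumulated = 0.0  # start accumulating again after a full pass
        previous = current
    return decisions


# Toy "modulated inputs" for four consecutive timesteps.
steps = [[1.0, 1.0], [1.05, 1.05], [1.1, 1.1], [2.0, 2.0]]
print(plan_steps(steps, rel_l1_thresh=0.2))
# prints ['compute', 'reuse', 'reuse', 'compute']
```

A larger threshold lets more consecutive steps fall under it, which is why raising `rel_l1_thresh` trades quality for speed.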
Usage

    from vllm_omni import Omni

    omni = Omni(
        model="Qwen/Qwen-Image",
        cache_backend="tea_cache",
        cache_config={"rel_l1_thresh": 0.2}
    )
    images = omni.generate("a cat")

    # Alternative: Using environment variable
    # export DIFFUSION_CACHE_BACKEND=tea_cache
Modules:
| Name | Description |
|---|---|
backend | TeaCache backend implementation. |
coefficient_estimator | |
config | |
extractors | Model-specific extractors for TeaCache. |
hook | Hook-based TeaCache implementation for vLLM-Omni. |
state | TeaCache state management. |
CacheContext dataclass ¶
Context object containing all model-specific information for caching.
This allows the TeaCacheHook to remain completely generic - all model-specific logic is encapsulated in the extractor that returns this context.
Attributes:
| Name | Type | Description |
|---|---|---|
modulated_input | Tensor | Tensor used for cache decision (similarity comparison). Must be a torch.Tensor extracted from the first transformer block, typically after applying normalization and modulation. |
hidden_states | Tensor | Current hidden states (will be modified by caching). Must be a torch.Tensor representing the main image/latent states after preprocessing but before transformer blocks. |
encoder_hidden_states | Tensor \| None | Optional encoder states (for dual-stream models). Set to None for single-stream models (e.g., Flux). For dual-stream models (e.g., Qwen), contains text encoder outputs. |
temb | Tensor | Timestep embedding tensor. Must be a torch.Tensor containing the timestep conditioning. |
run_transformer_blocks | Callable[[], tuple[Tensor, ...]] | Callable that executes model-specific transformer blocks. Signature: () -> tuple[torch.Tensor, ...] Returns: tuple containing: - [0]: processed hidden_states (required) - [1]: processed encoder_hidden_states (optional, only for dual-stream) Example for single-stream: def run_blocks(): h = hidden_states for block in module.transformer_blocks: h = block(h, temb=temb) return (h,) Example for dual-stream: def run_blocks(): h, e = hidden_states, encoder_hidden_states for block in module.transformer_blocks: e, h = block(h, e, temb=temb) return (h, e) |
postprocess | Callable[[Tensor], Any] | Callable that does model-specific output postprocessing. Signature: (torch.Tensor) -> Union[torch.Tensor, Transformer2DModelOutput, tuple] Takes the processed hidden_states and applies final transformations (normalization, projection) to produce the model output. Example: def postprocess(h): h = module.norm_out(h, temb) output = module.proj_out(h) return Transformer2DModelOutput(sample=output) |
extra_states | dict[str, Any] \| None | Optional dict for additional model-specific state. Use this for models that need to pass additional context beyond the standard fields. |
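The two callables are easier to read laid out as closures. The following is a minimal sketch under stated assumptions: `SketchContext` is a hypothetical stand-in that only mirrors the field names in the table, floats stand in for tensors, and the toy blocks are lambdas rather than real transformer blocks.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class SketchContext:
    """Hypothetical stand-in mirroring the CacheContext field names."""
    modulated_input: Any
    hidden_states: Any
    encoder_hidden_states: Optional[Any]
    temb: Any
    run_transformer_blocks: Callable[[], tuple]
    postprocess: Callable[[Any], Any]
    extra_states: Optional[dict] = None


def build_single_stream(hidden_states, temb, blocks):
    def run_blocks():
        h = hidden_states
        for block in blocks:
            h = block(h, temb)
        return (h,)  # single-stream: only hidden states

    def postprocess(h):
        return {"sample": h}  # placeholder for norm_out / proj_out

    modulated = hidden_states + temb  # placeholder for norm + modulation
    return SketchContext(modulated, hidden_states, None, temb, run_blocks, postprocess)


def build_dual_stream(hidden_states, encoder_hidden_states, temb, blocks):
    def run_blocks():
        h, e = hidden_states, encoder_hidden_states
        for block in blocks:
            e, h = block(h, e, temb)  # toy blocks return (encoder, hidden)
        return (h, e)  # dual-stream: hidden states first, then encoder states

    def postprocess(h):
        return {"sample": h}

    modulated = hidden_states + temb
    return SketchContext(modulated, hidden_states, encoder_hidden_states,
                         temb, run_blocks, postprocess)


single = build_single_stream(1.0, 0.5, [lambda h, t: h + t, lambda h, t: h * 2])
dual = build_dual_stream(1.0, 10.0, 0.5, [lambda h, e, t: (e + 1, h + t)])
print(single.run_transformer_blocks())  # prints (3.0,)
print(dual.run_transformer_blocks())    # prints (1.5, 11.0)
```

Because both callables close over the preprocessed inputs, a generic caller can invoke `run_transformer_blocks()` with no arguments and never needs to know the model's block signature.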
run_transformer_blocks instance-attribute ¶
validate ¶
Validate that the CacheContext contains valid data.
Raises:
| Type | Description |
|---|---|
TypeError | If fields have wrong types |
ValueError | If tensors have invalid properties |
RuntimeError | If callables fail basic invocation tests |
This method should be called after creating a CacheContext to catch common developer errors early with clear error messages.
TeaCacheBackend ¶
Bases: CacheBackend
TeaCache implementation using hooks.
TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up diffusion inference by reusing transformer block computations when consecutive timestep embeddings are similar.
The backend applies TeaCache hooks to the transformer which intercept the forward pass and implement the caching logic transparently.
Example

    from vllm_omni.diffusion.data import DiffusionCacheConfig

    backend = TeaCacheBackend(DiffusionCacheConfig(rel_l1_thresh=0.2))
    backend.enable(pipeline)
    # Generate with cache enabled
    backend.refresh(pipeline, num_inference_steps=50)  # Refresh before each generation
    # Access config attributes: backend.config.rel_l1_thresh
enable ¶
enable(pipeline: Any) -> None
Enable TeaCache on transformer using hooks.
This creates a TeaCacheConfig from the backend's DiffusionCacheConfig and applies the TeaCache hook to the transformer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
pipeline | Any | Diffusion pipeline instance. The backend extracts the transformer from `pipeline.transformer` and the transformer type from `pipeline.transformer.__class__.__name__`. | required |
refresh ¶
Refresh TeaCache state for new generation.
Clears all cached residuals and resets counters/accumulators. Should be called before each generation to ensure clean state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
pipeline | Any | Diffusion pipeline instance. Extracts transformer via pipeline.transformer. | required |
num_inference_steps | int | Number of inference steps for the current generation. Currently not used by TeaCache but accepted for interface consistency. | required |
verbose | bool | Whether to log refresh operations (default: True) | True |
TeaCacheConfig dataclass ¶
Configuration for TeaCache applied to transformer models.
TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up diffusion model inference by reusing transformer block computations when consecutive timestep embeddings are similar.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
rel_l1_thresh | float | Threshold for the accumulated relative L1 distance. While the accumulated distance stays below the threshold, the cached residual is reused. Values in [0.1, 0.3] work best. Reference points: 0.2 gives ~1.5x speedup with minimal quality loss; 0.4 gives ~1.8x speedup with slight quality loss; 0.6 gives ~2.0x speedup with noticeable quality loss. | 0.2 |
coefficients | list[float] \| None | Polynomial coefficients for rescaling L1 distance. If None, uses model-specific defaults based on transformer_type. | None |
transformer_type | str | Transformer class name (e.g., "QwenImageTransformer2DModel"). Auto-detected from `pipeline.transformer.__class__.__name__` in the backend. Defaults to "QwenImageTransformer2DModel". | 'QwenImageTransformer2DModel' |
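To make the role of `coefficients` concrete, here is a small sketch of polynomial rescaling. The coefficient ordering (highest degree first) and the coefficient values are assumptions made for illustration; the real model-specific defaults are not listed on this page, and `make_rescale` is a hypothetical helper.

```python
def make_rescale(coefficients):
    """Build p(x) from coefficients ordered highest degree first (assumed order)."""
    def rescale(x):
        result = 0.0
        for coefficient in coefficients:
            result = result * x + coefficient  # Horner's rule
        return result
    return rescale


# Hypothetical coefficients, chosen only for illustration: p(x) = 2x^2 + 0.5x
rescale = make_rescale([2.0, 0.5, 0.0])

# The rescaled distances, not the raw ones, are accumulated and compared
# against rel_l1_thresh.
accumulated = 0.0
for raw_distance in (0.25, 0.5):
    accumulated += rescale(raw_distance)
print(accumulated)  # prints 1.0
```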
TeaCacheHook ¶
Bases: ModelHook
ModelHook implementing TeaCache for transformer models.
This hook completely intercepts the transformer's forward pass and implements adaptive caching based on timestep embedding similarity. It's model-agnostic and supports multiple model types through extractor functions.
Key features:
- Zero changes to model code
- CFG-aware with separate states for positive/negative branches
- CFG-parallel compatible: properly detects branch identity across ranks
- Model-specific polynomial rescaling
- Auto-detection of model types
Attributes:
| Name | Type | Description |
|---|---|---|
config | | TeaCache configuration with thresholds and callbacks |
rescale_func | | Polynomial function for rescaling L1 distances |
state_manager | | Manages TeaCacheState across forward passes |
extractor_fn | | Model-specific function to extract modulated input |
initialize_hook ¶
Initialize the hook with the extractor registered for the transformer type named in the config.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | The module to initialize the hook for. | required |
Returns:
| Type | Description |
|---|---|
Module | The initialized module. |
new_forward ¶
Generic forward handler that works for ANY model.
This method is completely model-agnostic. All model-specific logic is encapsulated in the extractor function that returns a CacheContext.
The extractor does:
- Model-specific preprocessing
- Extraction of modulated input for cache decision
- Providing transformer execution callable
- Providing postprocessing callable

This hook does:
- CFG-aware state management
- Cache decision logic (generic)
- Residual caching and reuse
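The hook-side responsibilities listed above can be sketched roughly as follows. This is an illustrative approximation, not the actual `new_forward`: scalars stand in for tensors, `BranchState` and `cached_forward` are hypothetical names, the branch is passed in explicitly rather than detected, and polynomial rescaling is left out.

```python
class BranchState:
    """Per-branch cache state (hypothetical layout)."""
    def __init__(self):
        self.previous_modulated = None
        self.accumulated = 0.0
        self.residual = None


def cached_forward(state, modulated, hidden, run_blocks, postprocess, thresh):
    """One forward pass: reuse the cached residual or run the blocks."""
    reuse = False
    if state.previous_modulated is not None and state.residual is not None:
        change = abs(modulated - state.previous_modulated)
        state.accumulated += change / abs(state.previous_modulated)
        reuse = state.accumulated < thresh
    state.previous_modulated = modulated

    if reuse:
        out = hidden + state.residual    # skip the blocks entirely
    else:
        out = run_blocks()[0]            # full computation
        state.residual = out - hidden    # cache what the blocks added
        state.accumulated = 0.0
    return postprocess(out)


# Separate states keep the positive and negative CFG branches independent.
states = {"positive": BranchState(), "negative": BranchState()}
calls = []
outputs = {"positive": [], "negative": []}


def make_run(hidden, branch):
    def run_blocks():
        calls.append(branch)             # record each real block execution
        return (hidden * 2.0,)
    return run_blocks


for step, modulated in enumerate([1.0, 1.05, 2.0]):
    for branch in ("positive", "negative"):
        hidden = 1.0 + step
        result = cached_forward(states[branch], modulated, hidden,
                                make_run(hidden, branch), lambda h: h, thresh=0.2)
        outputs[branch].append(result)

print(calls)  # prints ['positive', 'negative', 'positive', 'negative']
```

In this toy run the middle step is served from the cached residual for both branches, so the blocks execute four times instead of six.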
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | Transformer module (any architecture) | required |
*args | Any | Positional arguments for model forward | () |
**kwargs | Any | Keyword arguments for model forward | {} |
Returns:
| Type | Description |
|---|---|
Any | Model output (format depends on model) |
reset_state ¶
Reset all cached states for a new inference run.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | The module to reset state for. | required |
Returns:
| Type | Description |
|---|---|
Module | The module with reset state. |
TeaCacheState ¶
State management for TeaCache hook.
Tracks caching state across diffusion timesteps, managing counters, accumulated distances, and cached residuals for the TeaCache algorithm.
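As a rough picture of what such a state object might hold, consider the sketch below. The field names are assumptions derived from the description above (counters, accumulated distances, cached residuals), not the actual attributes of `TeaCacheState`.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SketchState:
    """Hypothetical state container; field names are illustrative only."""
    step_count: int = 0
    accumulated_distance: float = 0.0
    previous_modulated_input: Optional[Any] = None
    cached_residual: Optional[Any] = None

    def reset(self) -> None:
        """Return to a clean state before a new generation."""
        self.step_count = 0
        self.accumulated_distance = 0.0
        self.previous_modulated_input = None
        self.cached_residual = None


state = SketchState()
state.step_count = 12
state.accumulated_distance = 0.17
state.cached_residual = [0.1, 0.2]
state.reset()
print(state == SketchState())  # prints True
```

Resetting matters because a residual cached for one prompt would otherwise leak into the first steps of the next generation.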
apply_teacache_hook ¶
apply_teacache_hook(
module: Module, config: TeaCacheConfig
) -> None
Apply TeaCache optimization to a transformer module.
This function registers a TeaCacheHook that completely intercepts the module's forward pass, implementing adaptive caching without any changes to the model code.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | Transformer model to optimize (e.g., QwenImageTransformer2DModel) | required |
config | TeaCacheConfig | TeaCacheConfig specifying caching parameters | required |
Example

    config = TeaCacheConfig(
        rel_l1_thresh=0.2,
        transformer_type="QwenImageTransformer2DModel"
    )
    apply_teacache_hook(transformer, config)
    # Transformer bound to the pipeline now uses TeaCache automatically,
    # no code changes needed!
register_extractor ¶
Register a new extractor function for a model type.
This allows extending TeaCache support to new models without modifying the core TeaCache code.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
transformer_cls_name | str | Transformer model type identifier (class name or type string) | required |
extractor_fn | Callable | Function with signature `(module, *args, **kwargs) -> CacheContext` | required |
Example

    def extract_flux_context(module, hidden_states, timestep, guidance=None, **kwargs):
        # Preprocessing
        temb = module.time_text_embed(timestep, guidance)
        # Extract modulated input
        modulated = module.transformer_blocks[0].norm1(hidden_states, emb=temb)
        # Define execution
        def run_blocks():
            h = hidden_states
            for block in module.transformer_blocks:
                h = block(h, temb=temb)
            return (h,)
        # Define postprocessing
        def postprocess(h):
            return module.proj_out(module.norm_out(h, temb))
        # Return context
        return CacheContext(modulated, hidden_states, None, temb, run_blocks, postprocess)

    register_extractor("FluxTransformer2DModel", extract_flux_context)
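One plausible way to picture the registration mechanism is a dictionary keyed by class name. The sketch below is an assumption about the general pattern, not the library's code: `_EXTRACTORS`, `register_sketch_extractor`, and `get_sketch_extractor` are hypothetical names, and the real lookup and error behavior may differ.

```python
from typing import Callable, Dict

# Hypothetical registry: transformer class name -> extractor function.
_EXTRACTORS: Dict[str, Callable] = {}


def register_sketch_extractor(transformer_cls_name: str, extractor_fn: Callable) -> None:
    """Associate an extractor with a transformer class name."""
    if not callable(extractor_fn):
        raise TypeError("extractor_fn must be callable")
    _EXTRACTORS[transformer_cls_name] = extractor_fn


def get_sketch_extractor(module) -> Callable:
    """Look up the extractor by the module's class name."""
    name = type(module).__name__
    try:
        return _EXTRACTORS[name]
    except KeyError:
        raise ValueError(f"No extractor registered for {name}") from None


class ToyTransformer:
    """Stand-in for a real transformer module."""


def extract_toy_context(module, *args, **kwargs):
    return "toy-context"  # a real extractor would return a CacheContext


register_sketch_extractor("ToyTransformer", extract_toy_context)
toy = ToyTransformer()
print(get_sketch_extractor(toy)(toy))  # prints toy-context
```

Keying on the class name is what lets a generic hook stay model-agnostic: supporting a new architecture means adding one registry entry rather than editing the hook.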