
vllm_omni.diffusion.cache.teacache

TeaCache: Timestep Embedding Aware Cache for diffusion model acceleration.

TeaCache speeds up diffusion inference by reusing transformer block computations when consecutive timestep embeddings are similar.

This implementation uses a hooks-based approach that requires zero changes to model code. Model developers only need to add an extractor function to support new models.

Usage

from vllm_omni import Omni

omni = Omni(
    model="Qwen/Qwen-Image",
    cache_backend="tea_cache",
    cache_config={"rel_l1_thresh": 0.2},
)
images = omni.generate("a cat")

Alternative: Using environment variable

export DIFFUSION_CACHE_BACKEND=tea_cache

Modules:

Name Description
backend

TeaCache backend implementation.

coefficient_estimator
config
extractors

Model-specific extractors for TeaCache.

hook

Hook-based TeaCache implementation for vLLM-Omni.

state

TeaCache state management.

CacheContext dataclass

Context object containing all model-specific information for caching.

This allows the TeaCacheHook to remain completely generic - all model-specific logic is encapsulated in the extractor that returns this context.

Attributes:

Name Type Description
modulated_input Tensor

Tensor used for cache decision (similarity comparison). Must be a torch.Tensor extracted from the first transformer block, typically after applying normalization and modulation.

hidden_states Tensor

Current hidden states (will be modified by caching). Must be a torch.Tensor representing the main image/latent states after preprocessing but before transformer blocks.

encoder_hidden_states Tensor | None

Optional encoder states (for dual-stream models). Set to None for single-stream models (e.g., Flux). For dual-stream models (e.g., Qwen), contains text encoder outputs.

temb Tensor

Timestep embedding tensor. Must be a torch.Tensor containing the timestep conditioning.

run_transformer_blocks Callable[[], tuple[Tensor, ...]]

Callable that executes model-specific transformer blocks. Signature: () -> tuple[torch.Tensor, ...]

Returns a tuple containing:
- [0]: processed hidden_states (required)
- [1]: processed encoder_hidden_states (optional, only for dual-stream)

Example for single-stream:

def run_blocks():
    h = hidden_states
    for block in module.transformer_blocks:
        h = block(h, temb=temb)
    return (h,)

Example for dual-stream:

def run_blocks():
    h, e = hidden_states, encoder_hidden_states
    for block in module.transformer_blocks:
        e, h = block(h, e, temb=temb)
    return (h, e)

postprocess Callable[[Tensor], Any]

Callable that does model-specific output postprocessing. Signature: (torch.Tensor) -> Union[torch.Tensor, Transformer2DModelOutput, tuple]

Takes the processed hidden_states and applies final transformations (normalization, projection) to produce the model output.

Example:

def postprocess(h):
    h = module.norm_out(h, temb)
    output = module.proj_out(h)
    return Transformer2DModelOutput(sample=output)

extra_states dict[str, Any] | None

Optional dict for additional model-specific state. Use this for models that need to pass additional context beyond the standard fields.

encoder_hidden_states instance-attribute

encoder_hidden_states: Tensor | None

extra_states class-attribute instance-attribute

extra_states: dict[str, Any] | None = None

hidden_states instance-attribute

hidden_states: Tensor

modulated_input instance-attribute

modulated_input: Tensor

postprocess instance-attribute

postprocess: Callable[[Tensor], Any]

run_transformer_blocks instance-attribute

run_transformer_blocks: Callable[[], tuple[Tensor, ...]]

temb instance-attribute

temb: Tensor

validate

validate() -> None

Validate that the CacheContext contains valid data.

Raises:

Type Description
TypeError

If fields have wrong types

ValueError

If tensors have invalid properties

RuntimeError

If callables fail basic invocation tests

This method should be called after creating a CacheContext to catch common developer errors early with clear error messages.
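The three error classes above can be illustrated with a minimal sketch. It uses a hypothetical stand-in class, `SketchContext`, with plain Python lists in place of tensors and only a subset of the fields; the real `CacheContext.validate` may check different properties, and the `RuntimeError` invocation tests are not reproduced here.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class SketchContext:
    """Hypothetical stand-in for CacheContext; lists replace torch.Tensor."""

    modulated_input: list
    hidden_states: list
    run_transformer_blocks: Callable[[], tuple]
    postprocess: Callable[[Any], Any]

    def validate(self) -> None:
        # Wrong field types surface as TypeError.
        for name in ("modulated_input", "hidden_states"):
            if not isinstance(getattr(self, name), list):
                raise TypeError(f"{name} must be a list in this sketch")
        for name in ("run_transformer_blocks", "postprocess"):
            if not callable(getattr(self, name)):
                raise TypeError(f"{name} must be callable")
        # Invalid data properties surface as ValueError.
        if len(self.hidden_states) == 0:
            raise ValueError("hidden_states must not be empty")


def try_validate(ctx: SketchContext) -> str:
    """Return 'ok' or the name of the exception raised by validate()."""
    try:
        ctx.validate()
    except (TypeError, ValueError) as exc:
        return type(exc).__name__
    return "ok"
```

Calling `validate()` right after construction, inside the extractor, keeps a malformed context from reaching the caching logic.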

TeaCacheBackend

Bases: CacheBackend

TeaCache implementation using hooks.

TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up diffusion inference by reusing transformer block computations when consecutive timestep embeddings are similar.

The backend applies TeaCache hooks to the transformer which intercept the forward pass and implement the caching logic transparently.

Example

from vllm_omni.diffusion.cache.teacache import TeaCacheBackend
from vllm_omni.diffusion.data import DiffusionCacheConfig

backend = TeaCacheBackend(DiffusionCacheConfig(rel_l1_thresh=0.2))
backend.enable(pipeline)

# Generate with cache enabled
backend.refresh(pipeline, num_inference_steps=50)  # Refresh before each generation

# Access config attributes: backend.config.rel_l1_thresh

enable

enable(pipeline: Any) -> None

Enable TeaCache on transformer using hooks.

This creates a TeaCacheConfig from the backend's DiffusionCacheConfig and applies the TeaCache hook to the transformer.

Parameters:

Name Type Description Default
pipeline Any

Diffusion pipeline instance. The backend extracts two values from it:
- transformer: pipeline.transformer
- transformer_type: pipeline.transformer.__class__.__name__

required
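The steps that `enable` performs before applying the hook could look roughly like this. It is a sketch under assumptions: `SketchTeaCacheConfig`, `DummyPipeline`, and `build_config` are hypothetical stand-ins, and only the class-name lookup is taken from the description above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchTeaCacheConfig:
    """Hypothetical stand-in mirroring the documented TeaCacheConfig fields."""

    rel_l1_thresh: float = 0.2
    coefficients: Optional[list] = None
    transformer_type: str = "QwenImageTransformer2DModel"


class FluxTransformer2DModel:
    """Dummy transformer; only its class name matters in this sketch."""


class DummyPipeline:
    """Dummy pipeline exposing a `transformer` attribute."""

    def __init__(self) -> None:
        self.transformer = FluxTransformer2DModel()


def build_config(pipeline, rel_l1_thresh: float) -> SketchTeaCacheConfig:
    # The transformer type is the class name of pipeline.transformer.
    transformer_type = pipeline.transformer.__class__.__name__
    return SketchTeaCacheConfig(
        rel_l1_thresh=rel_l1_thresh,
        transformer_type=transformer_type,
    )
```

Because the type is derived from the class name, callers do not need to pass `transformer_type` themselves when going through the backend.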

refresh

refresh(
    pipeline: Any,
    num_inference_steps: int,
    verbose: bool = True,
) -> None

Refresh TeaCache state for new generation.

Clears all cached residuals and resets counters/accumulators. Should be called before each generation to ensure clean state.

Parameters:

Name Type Description Default
pipeline Any

Diffusion pipeline instance. Extracts transformer via pipeline.transformer.

required
num_inference_steps int

Number of inference steps for the current generation. Currently not used by TeaCache but accepted for interface consistency.

required
verbose bool

Whether to log refresh operations (default: True)

True

TeaCacheConfig dataclass

Configuration for TeaCache applied to transformer models.

TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up diffusion model inference by reusing transformer block computations when consecutive timestep embeddings are similar.

Parameters:

Name Type Description Default
rel_l1_thresh float

Threshold for the accumulated relative L1 distance. While the accumulated distance stays below the threshold, the cached residual is reused. Values in [0.1, 0.3] are recommended; larger values give more speedup at the cost of quality:
- 0.2: ~1.5x speedup with minimal quality loss
- 0.4: ~1.8x speedup with slight quality loss
- 0.6: ~2.0x speedup with noticeable quality loss

0.2
coefficients list[float] | None

Polynomial coefficients for rescaling L1 distance. If None, uses model-specific defaults based on transformer_type.

None
transformer_type str

Transformer class name (e.g., "QwenImageTransformer2DModel"). Auto-detected from pipeline.transformer.__class__.__name__ in the backend. Defaults to "QwenImageTransformer2DModel".

'QwenImageTransformer2DModel'
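To see how `rel_l1_thresh` trades speed for quality, the following sketch replays a list of per-step distances and marks which steps would recompute the transformer blocks. The function `plan_steps` is hypothetical, and the assumptions that the first step always computes and that the accumulator resets after each recomputation are based on the usual TeaCache formulation rather than on this page.

```python
def plan_steps(distances: list, rel_l1_thresh: float) -> list:
    """Return one bool per step: True = recompute blocks, False = reuse cache.

    distances[i] is the (already rescaled) relative L1 distance between the
    modulated inputs of step i and step i - 1. distances[0] is ignored
    because the first step has nothing to compare against.
    """
    accumulated = 0.0
    plan = []
    for step, dist in enumerate(distances):
        if step == 0:
            plan.append(True)  # no cached residual exists yet
            continue
        accumulated += dist
        if accumulated < rel_l1_thresh:
            plan.append(False)  # below threshold: reuse the cached residual
        else:
            plan.append(True)  # threshold reached: recompute
            accumulated = 0.0  # start accumulating afresh
    return plan
```

With the same distances, a larger threshold lets more consecutive steps reuse the cache, which is where the extra speedup and the extra quality loss both come from.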

coefficients class-attribute instance-attribute

coefficients: list[float] | None = None

rel_l1_thresh class-attribute instance-attribute

rel_l1_thresh: float = 0.2

transformer_type class-attribute instance-attribute

transformer_type: str = 'QwenImageTransformer2DModel'

TeaCacheHook

Bases: ModelHook

ModelHook implementing TeaCache for transformer models.

This hook completely intercepts the transformer's forward pass and implements adaptive caching based on timestep embedding similarity. It's model-agnostic and supports multiple model types through extractor functions.

Key features:
- Zero changes to model code
- CFG-aware with separate states for positive/negative branches
- CFG-parallel compatible: properly detects branch identity across ranks
- Model-specific polynomial rescaling
- Auto-detection of model types

Attributes:

Name Type Description
config

TeaCache configuration with thresholds and callbacks

rescale_func

Polynomial function for rescaling L1 distances

state_manager

Manages TeaCacheState across forward passes

extractor_fn

Model-specific function to extract modulated input

config instance-attribute

config = config

extractor_fn instance-attribute

extractor_fn = None

rescale_func instance-attribute

rescale_func = poly1d(coefficients)
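Assuming `poly1d` here is `numpy.poly1d`, the coefficients are ordered from the highest degree to the constant term. The pure-Python sketch below evaluates a polynomial with that ordering using Horner's rule; the function name `rescale` and the sample coefficients in the usage note are illustrative only and are not the defaults of any model.

```python
def rescale(coefficients: list, x: float) -> float:
    """Evaluate a polynomial at x using Horner's rule.

    Coefficients are ordered highest degree first, so [2.0, 0.0, 1.0]
    represents 2*x**2 + 0*x + 1.
    """
    result = 0.0
    for c in coefficients:
        result = result * x + c
    return result
```

For example, `rescale([2.0, 0.0, 1.0], 3.0)` evaluates 2·3² + 1. In TeaCache the input `x` would be the raw relative L1 distance and the result is what gets accumulated.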

state_manager instance-attribute

state_manager = StateManager(TeaCacheState)
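The idea of keeping separate state for the positive and negative CFG branches can be sketched as a small store that creates one state object per branch from a factory. `BranchStateStore`, its `get` and `reset_all` methods, and `new_state` are hypothetical names; the actual `StateManager` interface is not documented on this page and may differ.

```python
from typing import Any, Callable


class BranchStateStore:
    """Hypothetical store holding one independent state per CFG branch."""

    def __init__(self, state_factory: Callable[[], Any]) -> None:
        self._state_factory = state_factory
        self._states: dict = {}

    def get(self, branch: str) -> Any:
        # Create the branch's state on first access, then keep returning it.
        if branch not in self._states:
            self._states[branch] = self._state_factory()
        return self._states[branch]

    def reset_all(self) -> None:
        # Drop every branch's state, e.g. before a new generation.
        self._states.clear()


def new_state() -> dict:
    """Illustrative state: a step counter and an accumulated distance."""
    return {"cnt": 0, "accumulated_rel_l1_distance": 0.0}
```

Keeping the branches apart matters because the conditional and unconditional passes see different inputs, so a residual cached for one is not a valid substitute for the other.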

initialize_hook

initialize_hook(module: Module) -> Module

Initialize hook with extractor from config transformer model type.

Parameters:

Name Type Description Default
module Module

The module to initialize the hook for.

required

Returns:

Type Description
Module

The initialized module.

new_forward

new_forward(
    module: Module, *args: Any, **kwargs: Any
) -> Any

Generic forward handler that works for ANY model.

This method is completely model-agnostic. All model-specific logic is encapsulated in the extractor function that returns a CacheContext.

The extractor does:
- Model-specific preprocessing
- Extraction of modulated input for cache decision
- Providing transformer execution callable
- Providing postprocessing callable

This hook does:
- CFG-aware state management
- Cache decision logic (generic)
- Residual caching and reuse
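The division of work between extractor and hook can be sketched for a single step as follows. This is a simplified illustration, not the hook's actual code: `cached_forward` is a hypothetical function, plain lists of floats stand in for tensors, the cache decision is passed in as a flag, and the encoder stream is omitted.

```python
from typing import Callable, Optional


def cached_forward(
    hidden_states: list,
    run_transformer_blocks: Callable[[], tuple],
    postprocess: Callable[[list], list],
    previous_residual: Optional[list],
    should_compute: bool,
) -> tuple:
    """Return (model_output, residual_to_cache) for one denoising step."""
    if should_compute or previous_residual is None:
        # Full path: run the blocks, then remember how much they changed
        # the hidden states.
        output = run_transformer_blocks()[0]
        residual = [o - h for o, h in zip(output, hidden_states)]
    else:
        # Cached path: skip the blocks and add the stored residual instead.
        residual = previous_residual
        output = [h + r for h, r in zip(hidden_states, residual)]
    # Postprocessing runs on both paths.
    return postprocess(output), residual
```

Caching the residual rather than the output lets the cached path still follow the current hidden states, which change from step to step.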

Parameters:

Name Type Description Default
module Module

Transformer module (any architecture)

required
*args Any

Positional arguments for model forward

()
**kwargs Any

Keyword arguments for model forward

{}

Returns:

Type Description
Any

Model output (format depends on model)

reset_state

reset_state(module: Module) -> Module

Reset all cached states for a new inference run.

Parameters:

Name Type Description Default
module Module

The module to reset state for.

required

Returns:

Type Description
Module

The module with reset state.

TeaCacheState

State management for TeaCache hook.

Tracks caching state across diffusion timesteps, managing counters, accumulated distances, and cached residuals for the TeaCache algorithm.

accumulated_rel_l1_distance instance-attribute

accumulated_rel_l1_distance = 0.0

cnt instance-attribute

cnt = 0

previous_modulated_input instance-attribute

previous_modulated_input: Tensor | None = None

previous_residual instance-attribute

previous_residual: Tensor | None = None

previous_residual_encoder instance-attribute

previous_residual_encoder: Tensor | None = None

reset

reset() -> None

Reset all state variables for a new inference run.
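A stand-in with the same fields as those listed above shows what a reset amounts to. `SketchState` is hypothetical and uses `Any` where the real class holds tensors; it assumes that a reset restores every field to its documented initial value.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SketchState:
    """Hypothetical stand-in mirroring the documented TeaCacheState fields."""

    cnt: int = 0
    accumulated_rel_l1_distance: float = 0.0
    previous_modulated_input: Optional[Any] = None
    previous_residual: Optional[Any] = None
    previous_residual_encoder: Optional[Any] = None

    def reset(self) -> None:
        # Return every field to its initial value so that nothing from the
        # previous generation influences the next one.
        self.cnt = 0
        self.accumulated_rel_l1_distance = 0.0
        self.previous_modulated_input = None
        self.previous_residual = None
        self.previous_residual_encoder = None
```

Skipping the reset between generations would let a residual computed for one prompt be reused for another.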

apply_teacache_hook

apply_teacache_hook(
    module: Module, config: TeaCacheConfig
) -> None

Apply TeaCache optimization to a transformer module.

This function registers a TeaCacheHook that completely intercepts the module's forward pass, implementing adaptive caching without any changes to the model code.
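Intercepting a forward pass without touching model code can be illustrated with a generic wrapper. This is only a sketch of the pattern: `attach_forward_hook`, `counting_forward`, and `TinyModule` are hypothetical, and the real implementation goes through `ModelHook` and `TeaCacheHook.new_forward`, whose signature differs from the one used here.

```python
class TinyModule:
    """Dummy module with a trivial forward pass."""

    def forward(self, x: float) -> float:
        return x * 2.0


def attach_forward_hook(module, new_forward) -> None:
    """Route calls to module.forward through new_forward."""
    original_forward = module.forward

    def hooked(*args, **kwargs):
        # The hook decides whether and how to call the original forward.
        return new_forward(module, original_forward, *args, **kwargs)

    # Shadow the class-level method on this instance only.
    module.forward = hooked


def counting_forward(module, original_forward, *args, **kwargs):
    """Example hook: count calls, then delegate to the original forward."""
    module.calls = getattr(module, "calls", 0) + 1
    return original_forward(*args, **kwargs)
```

The model class itself is unchanged; only the instance's `forward` attribute is replaced, which is what makes the approach non-invasive.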

Parameters:

Name Type Description Default
module Module

Transformer model to optimize (e.g., QwenImageTransformer2DModel)

required
config TeaCacheConfig

TeaCacheConfig specifying caching parameters

required
Example

config = TeaCacheConfig(
    rel_l1_thresh=0.2,
    transformer_type="QwenImageTransformer2DModel",
)
apply_teacache_hook(transformer, config)
# Transformer bound to the pipeline now uses TeaCache automatically,
# no code changes needed!

register_extractor

register_extractor(
    transformer_cls_name: str, extractor_fn: Callable
) -> None

Register a new extractor function for a model type.

This allows extending TeaCache support to new models without modifying the core TeaCache code.

Parameters:

Name Type Description Default
transformer_cls_name str

Transformer model type identifier (class name or type string)

required
extractor_fn Callable

Function with signature (module, *args, **kwargs) -> CacheContext

required
Example

def extract_flux_context(module, hidden_states, timestep, guidance=None, **kwargs):
    # Preprocessing
    temb = module.time_text_embed(timestep, guidance)

    # Extract modulated input
    modulated = module.transformer_blocks[0].norm1(hidden_states, emb=temb)

    # Define execution
    def run_blocks():
        h = hidden_states
        for block in module.transformer_blocks:
            h = block(h, temb=temb)
        return (h,)

    # Define postprocessing
    def postprocess(h):
        return module.proj_out(module.norm_out(h, temb))

    # Return context
    return CacheContext(modulated, hidden_states, None, temb, run_blocks, postprocess)

register_extractor("FluxTransformer2DModel", extract_flux_context)
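A name-keyed registry of this kind can be sketched in a few lines. The names `_EXTRACTORS`, `register_extractor_sketch`, and `get_extractor_sketch` are hypothetical, and raising `ValueError` for an unknown model is an assumption; the library's own lookup and error handling may differ.

```python
from typing import Callable

# Hypothetical registry: transformer class name -> extractor function.
_EXTRACTORS: dict = {}


def register_extractor_sketch(transformer_cls_name: str, extractor_fn: Callable) -> None:
    """Associate an extractor with a transformer class name."""
    _EXTRACTORS[transformer_cls_name] = extractor_fn


def get_extractor_sketch(module) -> Callable:
    """Look up the extractor registered for the module's class name."""
    name = type(module).__name__
    try:
        return _EXTRACTORS[name]
    except KeyError:
        raise ValueError(f"No extractor registered for {name}") from None


class MyTransformer:
    """Dummy transformer used only to demonstrate the lookup."""


def extract_my_context(module, *args, **kwargs):
    """Placeholder extractor; a real one would return a CacheContext."""
    return "context"


register_extractor_sketch("MyTransformer", extract_my_context)
```

Keying on the class name means a new model is supported by registering one function, with no change to the hook.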