
vllm_omni.diffusion.cache.teacache.extractors

Model-specific extractors for TeaCache.

This module provides a registry of extractor functions that know how to extract modulated inputs from different transformer architectures. Adding support for a new model requires only adding a new extractor function to the registry.

With the Option B enhancement, extractors return a CacheContext object that bundles all model-specific information needed for generic caching, including preprocessing, transformer execution, and postprocessing logic.

EXTRACTOR_REGISTRY module-attribute

EXTRACTOR_REGISTRY: dict[str, Callable] = {
    "QwenImageTransformer2DModel": extract_qwen_context,
    "Bagel": extract_bagel_context,
    "ZImageTransformer2DModel": extract_zimage_context,
    "Flux2Klein": extract_flux2_klein_context,
    "StableAudioDiTModel": extract_stable_audio_context,
    "Flux2Transformer2DModel": extract_flux2_context,
    "LongCatImageTransformer2DModel": extract_longcat_context,
    "FluxTransformer2DModel": extract_flux_context,
}

logger module-attribute

logger = init_logger(__name__)

CacheContext dataclass

Context object containing all model-specific information for caching.

This keeps TeaCacheHook completely generic: all model-specific logic is encapsulated in the extractor that returns this context.

Attributes:

Name Type Description
modulated_input Tensor

Tensor used for cache decision (similarity comparison). Must be a torch.Tensor extracted from the first transformer block, typically after applying normalization and modulation.

hidden_states Tensor

Current hidden states (will be modified by caching). Must be a torch.Tensor representing the main image/latent states after preprocessing but before transformer blocks.

encoder_hidden_states Tensor | None

Optional encoder states (for dual-stream models). Set to None for single-stream models. For dual-stream models (e.g., Qwen), contains text encoder outputs.

temb Tensor

Timestep embedding tensor. Must be a torch.Tensor containing the timestep conditioning.

run_transformer_blocks Callable[[], tuple[Tensor, ...]]

Callable that executes model-specific transformer blocks. Signature: () -> tuple[torch.Tensor, ...]

Returns: a tuple containing:
  • [0]: processed hidden_states (required)
  • [1]: processed encoder_hidden_states (optional, dual-stream models only)

Example for single-stream:

def run_blocks():
    h = hidden_states
    for block in module.transformer_blocks:
        h = block(h, temb=temb)
    return (h,)

Example for dual-stream:

def run_blocks():
    h, e = hidden_states, encoder_hidden_states
    for block in module.transformer_blocks:
        # dual-stream blocks return (encoder_hidden_states, hidden_states)
        e, h = block(h, e, temb=temb)
    return (h, e)

postprocess Callable[[Tensor], Any]

Callable that performs model-specific output postprocessing. Signature: (torch.Tensor) -> Union[torch.Tensor, Transformer2DModelOutput, tuple]

Takes the processed hidden_states and applies final transformations (normalization, projection) to produce the model output.

Example:

def postprocess(h):
    h = module.norm_out(h, temb)
    output = module.proj_out(h)
    return Transformer2DModelOutput(sample=output)

extra_states dict[str, Any] | None

Optional dict for additional model-specific state. Use this for models that need to pass additional context beyond the standard fields.
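The following framework-free sketch mirrors the dataclass and shows how a generic hook might consume it. It assumes, as the TeaCache reference implementations do, that the hook caches the residual the transformer blocks add to hidden_states. `generic_step` and the `cache` dict are hypothetical. Lists of floats replace tensors so the example runs without torch.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class CacheContext:
    # Same field order as the documented dataclass; plain lists stand in for tensors.
    modulated_input: Any
    hidden_states: Any
    encoder_hidden_states: Optional[Any]  # None for single-stream models
    temb: Any
    run_transformer_blocks: Callable[[], tuple[Any, ...]]
    postprocess: Callable[[Any], Any]
    extra_states: Optional[dict[str, Any]] = None


def generic_step(ctx: CacheContext, compute: bool, cache: dict[str, Any]) -> Any:
    """Hypothetical hook body: run the blocks, or replay a cached residual."""
    if compute:
        outputs = ctx.run_transformer_blocks()
        h = outputs[0]  # [0] is always the processed hidden_states
        # Remember what the blocks added so a later step can reuse it.
        cache["residual"] = [a - b for a, b in zip(h, ctx.hidden_states)]
    else:
        # Cache hit: skip every transformer block and apply the stored residual.
        h = [a + r for a, r in zip(ctx.hidden_states, cache["residual"])]
    return ctx.postprocess(h)
```

On a cache hit the hook never calls run_transformer_blocks. This is why the extractor hands the blocks over as a closure rather than running them eagerly.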

encoder_hidden_states instance-attribute

encoder_hidden_states: Tensor | None

extra_states class-attribute instance-attribute

extra_states: dict[str, Any] | None = None

hidden_states instance-attribute

hidden_states: Tensor

modulated_input instance-attribute

modulated_input: Tensor

postprocess instance-attribute

postprocess: Callable[[Tensor], Any]

run_transformer_blocks instance-attribute

run_transformer_blocks: Callable[[], tuple[Tensor, ...]]

temb instance-attribute

temb: Tensor

validate

validate() -> None

Validate that the CacheContext contains valid data.

Raises:

Type Description
TypeError

If fields have wrong types

ValueError

If tensors have invalid properties

RuntimeError

If callables fail basic invocation tests

This method should be called after creating a CacheContext to catch common developer errors early with clear error messages.
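validate() is documented only by the errors it raises. Below is a hypothetical sketch of the type and value checks it describes. The function name and the `tensor_type` parameter are assumptions; the parameter lets the example run with lists instead of torch.Tensor. The RuntimeError "basic invocation tests" are left out because the page does not say what they invoke.

```python
from typing import Any


def validate_context_sketch(ctx: Any, tensor_type: type) -> None:
    """Hypothetical checks in the spirit of CacheContext.validate()."""
    # Required tensor fields: right type and not empty.
    for name in ("modulated_input", "hidden_states", "temb"):
        value = getattr(ctx, name)
        if not isinstance(value, tensor_type):
            raise TypeError(
                f"{name} must be {tensor_type.__name__}, got {type(value).__name__}"
            )
        if len(value) == 0:  # with torch you would test value.numel() == 0
            raise ValueError(f"{name} must not be empty")
    # encoder_hidden_states is optional but, if present, must also be a tensor.
    enc = ctx.encoder_hidden_states
    if enc is not None and not isinstance(enc, tensor_type):
        raise TypeError("encoder_hidden_states must be a tensor or None")
    # The two hooks into model code must at least be callable.
    for name in ("run_transformer_blocks", "postprocess"):
        if not callable(getattr(ctx, name)):
            raise TypeError(f"{name} must be callable")
```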

extract_bagel_context

extract_bagel_context(
    module: Module,
    x_t: Tensor,
    timestep: Tensor | float | int,
    packed_vae_token_indexes: LongTensor,
    packed_vae_position_ids: LongTensor,
    packed_text_ids: LongTensor,
    packed_text_indexes: LongTensor,
    packed_position_ids: LongTensor,
    packed_seqlens: IntTensor,
    past_key_values: Any,
    **kwargs: Any,
) -> CacheContext

Extract cache context for Bagel model.

Parameters:

Name Type Description Default
module Module

Bagel instance

required
x_t Tensor

Latent image input

required
timestep Tensor | float | int

Current timestep

required
packed_vae_token_indexes LongTensor

Indexes for VAE tokens in packed sequence

required
packed_vae_position_ids LongTensor

Position IDs for VAE tokens

required
packed_text_ids LongTensor

Text token IDs

required
packed_text_indexes LongTensor

Indexes for text tokens in packed sequence

required
packed_position_ids LongTensor

Global position IDs

required
packed_seqlens IntTensor

Sequence lengths

required
past_key_values Any

KV cache

required
**kwargs Any

Additional keyword arguments

{}

Returns:

Type Description
CacheContext

CacheContext with all information needed for generic caching

extract_flux2_context

extract_flux2_context(
    module: Module,
    hidden_states: Tensor,
    encoder_hidden_states: Tensor = None,
    timestep: LongTensor = None,
    img_ids: Tensor = None,
    txt_ids: Tensor = None,
    guidance: Tensor | None = None,
    joint_attention_kwargs: dict[str, Any] | None = None,
    return_dict: bool = True,
    **kwargs: Any,
) -> CacheContext

Extract cache context for Flux2Transformer2DModel.

This is the ONLY Flux2-specific code needed for TeaCache support. It encapsulates preprocessing, modulated input extraction, transformer execution, and postprocessing logic.

Parameters:

Name Type Description Default
module Module

Flux2Transformer2DModel instance

required
hidden_states Tensor

Input hidden states tensor

required
encoder_hidden_states Tensor

Text encoder outputs

None
timestep LongTensor

Current diffusion timestep

None
img_ids Tensor

Image position IDs for position embedding

None
txt_ids Tensor

Text position IDs for position embedding

None
guidance Tensor | None

Optional guidance scale for CFG

None
joint_attention_kwargs dict[str, Any] | None

Additional attention arguments

None
return_dict bool

Whether to return a Transformer2DModelOutput instead of a plain tensor

True
**kwargs Any

Additional keyword arguments ignored by this extractor

{}

Returns:

Type Description
CacheContext

CacheContext with all information needed for generic caching
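Extractors that accept return_dict presumably honor it in the postprocess closure, since that closure builds the final output. The following minimal sketch assumes the diffusers convention: a Transformer2DModelOutput when return_dict is True, otherwise a one-element tuple. `make_postprocess` is hypothetical, and norm_out and proj_out are passed in as plain callables. The output class is a local stand-in with only the `sample` field.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Transformer2DModelOutput:
    """Local stand-in for diffusers' output class (only the `sample` field)."""
    sample: Any


def make_postprocess(norm_out: Callable[[Any, Any], Any],
                     proj_out: Callable[[Any], Any],
                     temb: Any, return_dict: bool) -> Callable[[Any], Any]:
    """Hypothetical factory for a postprocess closure that honors return_dict."""
    def postprocess(h: Any) -> Any:
        output = proj_out(norm_out(h, temb))  # final norm, then projection
        if not return_dict:
            return (output,)  # tuple form, as diffusers models return
        return Transformer2DModelOutput(sample=output)
    return postprocess
```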

extract_flux2_klein_context

extract_flux2_klein_context(
    module: Module,
    hidden_states: Tensor,
    encoder_hidden_states: Tensor | None = None,
    timestep: LongTensor = None,
    img_ids: Tensor = None,
    txt_ids: Tensor = None,
    guidance: Tensor | None = None,
    joint_attention_kwargs: dict[str, Any] | None = None,
    **kwargs: Any,
) -> CacheContext

Extract cache context for Flux2Klein model.

Caches the full transformer output (including single_transformer_blocks). When the cache is reused, the single_transformer_blocks are skipped as well, maximizing the speedup.

Parameters:

Name Type Description Default
module Module

Flux2Transformer2DModel instance

required
hidden_states Tensor

Input image hidden states tensor

required
encoder_hidden_states Tensor | None

Input text hidden states tensor

None
timestep LongTensor

Current diffusion timestep

None
img_ids Tensor

Image position IDs for RoPE

None
txt_ids Tensor

Text position IDs for RoPE

None
guidance Tensor | None

Optional guidance scale for CFG

None
joint_attention_kwargs dict[str, Any] | None

Additional attention kwargs

None
**kwargs Any

Additional keyword arguments

{}

Returns:

Type Description
CacheContext

CacheContext with all information needed for generic caching
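The cached output is taken after the single_transformer_blocks stage, so reusing it skips that stage too. The toy sketch below builds a run_transformer_blocks closure for such a two-stage model. It assumes a FLUX-style layout: text tokens are concatenated in front of the image tokens for the single-stream stage, and only the image tokens are returned. `make_run_blocks`, the block call signatures, and the list-based "tensors" (one float per token) are hypothetical.

```python
from typing import Callable

Tokens = list[float]
DualBlock = Callable[[Tokens, Tokens], tuple[Tokens, Tokens]]
SingleBlock = Callable[[Tokens], Tokens]


def make_run_blocks(img: Tokens, txt: Tokens,
                    double_blocks: list[DualBlock],
                    single_blocks: list[SingleBlock]) -> Callable[[], tuple[Tokens, ...]]:
    """Hypothetical closure factory covering both stages of the stack."""
    def run_blocks() -> tuple[Tokens, ...]:
        h, e = img, txt
        # Stage 1: dual-stream blocks update the image and text streams.
        for block in double_blocks:
            e, h = block(h, e)
        # Stage 2: concatenate [text, image] and run the single-stream blocks.
        x = e + h
        for block in single_blocks:
            x = block(x)
        # Keep only the image tokens; their residual then spans both stages.
        return (x[len(e):],)
    return run_blocks
```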

extract_flux_context

extract_flux_context(
    module: Module,
    hidden_states: Tensor,
    encoder_hidden_states: Tensor,
    pooled_projections: Tensor,
    timestep: LongTensor,
    img_ids: Tensor,
    txt_ids: Tensor,
    guidance: Tensor | None = None,
    joint_attention_kwargs: dict[str, Any] | None = None,
    **kwargs: Any,
) -> CacheContext

Extract cache context for FluxTransformer2DModel.

This mirrors the standard FLUX.1 transformer forward path while exposing the first modulated image stream tensor as TeaCache's similarity signal.

Parameters:

Name Type Description Default
module Module

FluxTransformer2DModel instance

required
hidden_states Tensor

Input image hidden states tensor

required
encoder_hidden_states Tensor

Text encoder outputs

required
pooled_projections Tensor

Pooled text conditioning

required
timestep LongTensor

Current diffusion timestep

required
img_ids Tensor

Image position IDs for RoPE

required
txt_ids Tensor

Text position IDs for RoPE

required
guidance Tensor | None

Optional guidance scale for guidance-distilled variants

None
joint_attention_kwargs dict[str, Any] | None

Additional attention kwargs

None
**kwargs Any

Additional keyword arguments

{}

Returns:

Type Description
CacheContext

CacheContext with all information needed for generic caching
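In diffusers' FLUX.1 dual-stream blocks, the first norm (AdaLayerNormZero) applies a parameter-free LayerNorm. It then applies a shift and scale derived from the timestep embedding. The result is presumably the "first modulated image stream tensor" referred to above. Below is a framework-free sketch of that modulation for a single token. In the real model, shift and scale would come from a linear projection of temb (not shown).

```python
import math


def layer_norm(x: list[float], eps: float = 1e-6) -> list[float]:
    """Parameter-free LayerNorm over one token's feature vector."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def modulate(x: list[float], shift: list[float], scale: list[float]) -> list[float]:
    """adaLN-style modulation: norm(x) * (1 + scale) + shift."""
    return [n * (1.0 + s) + b for n, s, b in zip(layer_norm(x), scale, shift)]
```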

extract_longcat_context

extract_longcat_context(
    module: Module,
    hidden_states,
    timestep,
    guidance,
    encoder_hidden_states,
    txt_ids,
    img_ids,
    **kwargs,
) -> CacheContext

Extract the cache context for LongCat Image.

As with the other extractors, this is currently the only LongCat-specific code needed for TeaCache support. It encapsulates preprocessing, modulated input extraction, transformer execution, and postprocessing logic.

Positional and keyword arguments are identical to the inputs of LongCat Image's forward().

Returns:

Type Description
CacheContext

CacheContext with all information needed for generic caching
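Extractors receive the module followed by the same arguments as the model's forward(). When writing a new extractor such as this one, a quick signature check can therefore catch mismatches. The helper below is a hypothetical development aid built on inspect.signature, not part of vllm_omni.

```python
import inspect
from typing import Callable


def missing_forward_params(forward: Callable, extractor: Callable) -> list[str]:
    """Names of forward() parameters that `extractor` could not receive.

    Assumes the extractor's first parameter is the module and that the
    remaining parameters mirror forward(); **kwargs absorbs anything else.
    """
    fwd_params = [
        p for p in inspect.signature(forward).parameters.values()
        if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
    ]
    ext_params = list(inspect.signature(extractor).parameters.values())[1:]  # skip `module`
    if any(p.kind is p.VAR_KEYWORD for p in ext_params):
        return []  # **kwargs accepts any leftover keyword
    accepted = {p.name for p in ext_params}
    return [p.name for p in fwd_params if p.name not in accepted]
```

Pass a bound method such as `model.forward` so that `self` is already excluded from the signature.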

extract_qwen_context

extract_qwen_context(
    module: Module,
    hidden_states: Tensor,
    encoder_hidden_states: Tensor,
    encoder_hidden_states_mask: Tensor,
    timestep: Tensor | float | int,
    img_shapes: Tensor,
    txt_seq_lens: Tensor,
    guidance: Tensor | None = None,
    additional_t_cond: Tensor | None = None,
    attention_kwargs: dict[str, Any] | None = None,
    **kwargs: Any,
) -> CacheContext

Extract cache context for QwenImageTransformer2DModel.

This is the ONLY Qwen-specific code needed for TeaCache support. It encapsulates preprocessing, modulated input extraction, transformer execution, and postprocessing logic.

Parameters:

Name Type Description Default
module Module

QwenImageTransformer2DModel instance

required
hidden_states Tensor

Input hidden states tensor

required
encoder_hidden_states Tensor

Text encoder outputs

required
encoder_hidden_states_mask Tensor

Mask for text encoder

required
timestep Tensor | float | int

Current diffusion timestep

required
img_shapes Tensor

Image shapes for position embedding

required
txt_seq_lens Tensor

Text sequence lengths

required
guidance Tensor | None

Optional guidance scale for CFG

None
additional_t_cond Tensor | None

Optional additional timestep conditioning

None
attention_kwargs dict[str, Any] | None

Additional attention arguments

None
**kwargs Any

Additional keyword arguments ignored by this extractor

{}

Returns:

Type Description
CacheContext

CacheContext with all information needed for generic caching

extract_stable_audio_context

extract_stable_audio_context(
    module: Module,
    hidden_states: Tensor,
    timestep: Tensor,
    encoder_hidden_states: Tensor,
    global_hidden_states: Tensor | None = None,
    rotary_embedding: tuple[Tensor, Tensor] | None = None,
    return_dict: bool = True,
    attention_mask: Tensor | None = None,
    encoder_attention_mask: Tensor | None = None,
    **kwargs: Any,
) -> CacheContext

Extract cache context for Stable Audio DiT model.

Architecture Notes
  • Stable Audio uses standard LayerNorm
  • Timestep conditioning via global_hidden_states prepended to sequence
  • Single-stream model (cross-attention handled separately)
  • Input: [B, C, L] (C=in_channels) -> transpose -> [B, L, C] -> project -> [B, L, inner_dim]
  • Global states prepended: [B, 1+L, inner_dim]
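The shape bookkeeping in these notes can be traced step by step. The small sketch below only manipulates shape tuples. The numbers used with it are illustrative, not Stable Audio's configuration.

```python
def stable_audio_token_shapes(batch: int, in_channels: int, length: int,
                              inner_dim: int) -> list[tuple[int, int, int]]:
    """Shapes after each preprocessing step listed in the notes above."""
    x = (batch, in_channels, length)       # input: [B, C, L]
    x_t = (x[0], x[2], x[1])               # transpose: [B, L, C]
    proj = (x_t[0], x_t[1], inner_dim)     # project: [B, L, inner_dim]
    seq = (proj[0], 1 + proj[1], proj[2])  # prepend global state: [B, 1+L, inner_dim]
    return [x, x_t, proj, seq]
```

For example, `stable_audio_token_shapes(2, 64, 1024, 1536)[-1]` is `(2, 1025, 1536)`: one extra token per item for the prepended global state.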

extract_zimage_context

extract_zimage_context(
    module: Module,
    x: list[Tensor],
    t: Tensor,
    cap_feats: list[Tensor],
    patch_size: int = 2,
    f_patch_size: int = 1,
    **kwargs: Any,
) -> CacheContext

Extract cache context for ZImageTransformer2DModel.

This is the ONLY Z-Image-specific code needed for TeaCache support. It encapsulates preprocessing, modulated input extraction, transformer execution, and postprocessing logic.

Parameters:

Name Type Description Default
module Module

ZImageTransformer2DModel instance

required
x list[Tensor]

List of image tensors per batch item

required
t Tensor

Timestep tensor

required
cap_feats list[Tensor]

List of caption feature tensors per batch item

required
patch_size int

Patch size for patchification (default: 2)

2
f_patch_size int

Frame patch size (default: 1)

1
**kwargs Any

Additional keyword arguments ignored by this extractor

{}

Returns:

Type Description
CacheContext

CacheContext with all information needed for generic caching
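Patchification turns each per-item image tensor into a sequence of tokens whose count depends on patch_size and f_patch_size. The arithmetic below assumes each item in x is laid out as (C, F, H, W) with dimensions that divide evenly; neither is documented here. Because x is a list, items can presumably differ in resolution and therefore in token count. `patch_grid` is a hypothetical helper.

```python
def patch_grid(c: int, f: int, h: int, w: int,
               patch_size: int = 2, f_patch_size: int = 1) -> tuple[int, int]:
    """(num_tokens, token_dim) for one item under the assumed (C, F, H, W) layout."""
    if f % f_patch_size or h % patch_size or w % patch_size:
        raise ValueError("dimensions must be divisible by the patch sizes")
    # One token per (f_patch_size x patch_size x patch_size) patch.
    num_tokens = (f // f_patch_size) * (h // patch_size) * (w // patch_size)
    # Each token flattens all channels of its patch.
    token_dim = c * f_patch_size * patch_size * patch_size
    return num_tokens, token_dim
```

With the defaults, `patch_grid(16, 1, 128, 128)` gives `(4096, 64)`: a 64 x 64 grid of tokens, each holding 16 channels times a 2 x 2 patch.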

get_extractor

get_extractor(transformer_cls_name: str) -> Callable

Get extractor function for given transformer class.

This function looks up the extractor by the exact transformer_cls_name string, which should match the class name of the pipeline's transformer (i.e., pipeline.transformer.__class__.__name__).

Parameters:

Name Type Description Default
transformer_cls_name str

Transformer class name (e.g., "QwenImageTransformer2DModel"). Must exactly match a key in EXTRACTOR_REGISTRY.

required

Returns:

Type Description
Callable

Extractor function with signature (module, *args, **kwargs) -> CacheContext

Raises:

Type Description
ValueError

If the model type is not found in the registry

Example

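# Get extractor for QwenImageTransformer2DModel
extractor = get_extractor("QwenImageTransformer2DModel")
# Pass forward() arguments by keyword: positionally, timestep would land in
# encoder_hidden_states_mask, which precedes it in extract_qwen_context
ctx = extractor(
    transformer,
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    timestep=timestep,
    # ... remaining forward() arguments
)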
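The lookup documented for get_extractor() is an exact, case-sensitive dictionary match that raises ValueError on a miss. The standalone sketch below reproduces that behavior. The error message wording and the explicit `registry` parameter are hypothetical.

```python
from typing import Callable


def get_extractor_sketch(registry: dict[str, Callable], transformer_cls_name: str) -> Callable:
    """Exact-match lookup mirroring the documented get_extractor() behavior."""
    try:
        return registry[transformer_cls_name]
    except KeyError:
        available = ", ".join(sorted(registry))
        raise ValueError(
            f"No TeaCache extractor for {transformer_cls_name!r}; available: {available}"
        ) from None
```

Deriving the key from `type(transformer).__name__` at the call site avoids typos, since the match is exact.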

register_extractor

register_extractor(
    transformer_cls_name: str, extractor_fn: Callable
) -> None

Register a new extractor function for a model type.

This allows extending TeaCache support to new models without modifying the core TeaCache code.

Parameters:

Name Type Description Default
transformer_cls_name str

Transformer model type identifier (class name or type string)

required
extractor_fn Callable

Function with signature (module, *args, **kwargs) -> CacheContext

required
Example

def extract_my_context(module, hidden_states, timestep, guidance=None, **kwargs):
    # Preprocessing
    temb = module.time_text_embed(timestep, guidance)
    # Extract modulated input; if norm1 returns a tuple
    # (as diffusers' AdaLayerNormZero does), take element [0]
    modulated = module.transformer_blocks[0].norm1(hidden_states, emb=temb)
    # Define execution
    def run_blocks():
        h = hidden_states
        for block in module.transformer_blocks:
            h = block(h, temb=temb)
        return (h,)
    # Define postprocessing
    def postprocess(h):
        return module.proj_out(module.norm_out(h, temb))
    # Build, validate, and return the context (single-stream: no encoder states)
    ctx = CacheContext(modulated, hidden_states, None, temb, run_blocks, postprocess)
    ctx.validate()
    return ctx

# Hypothetical class name; avoid keys already in EXTRACTOR_REGISTRY such as "FluxTransformer2DModel"
register_extractor("MyTransformer2DModel", extract_my_context)
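Registration presumably amounts to storing the function under the exact class-name key that get_extractor() later matches. Below is a standalone sketch of that pattern with a toy model. `registry`, `register_extractor_sketch`, `MyTransformer2DModel`, and `extract_my_context` are hypothetical stand-ins, not the library's objects. The callable check is an assumed extra.

```python
from typing import Any, Callable

# Standalone stand-in for EXTRACTOR_REGISTRY (not the library's object).
registry: dict[str, Callable] = {}


def register_extractor_sketch(transformer_cls_name: str, extractor_fn: Callable) -> None:
    """Store extractor_fn under the exact class-name key a lookup will use."""
    if not callable(extractor_fn):
        raise TypeError("extractor_fn must be callable")
    registry[transformer_cls_name] = extractor_fn


class MyTransformer2DModel:  # hypothetical model class
    pass


def extract_my_context(module: Any, hidden_states: Any, **kwargs: Any) -> dict:
    # A real extractor would build, validate, and return a CacheContext.
    return {"module": type(module).__name__, "hidden_states": hidden_states}


register_extractor_sketch("MyTransformer2DModel", extract_my_context)

# Lookup uses the runtime class name, so the key must match it exactly.
model = MyTransformer2DModel()
extractor = registry[type(model).__name__]
```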