vllm_omni.diffusion.cache.teacache.extractors
Model-specific extractors for TeaCache.
This module provides a registry of extractor functions that know how to extract modulated inputs from different transformer architectures. Adding support for a new model requires only adding a new extractor function to the registry.
Extractors return a CacheContext object (the "Option B" design) that contains all the model-specific information needed for generic caching, including preprocessing, transformer execution, and postprocessing logic.
EXTRACTOR_REGISTRY module-attribute
EXTRACTOR_REGISTRY: dict[str, Callable] = {
"QwenImageTransformer2DModel": extract_qwen_context,
"Bagel": extract_bagel_context,
"ZImageTransformer2DModel": extract_zimage_context,
"Flux2Klein": extract_flux2_klein_context,
"StableAudioDiTModel": extract_stable_audio_context,
"Flux2Transformer2DModel": extract_flux2_context,
"LongCatImageTransformer2DModel": extract_longcat_context,
"FluxTransformer2DModel": extract_flux_context,
}
CacheContext dataclass
Context object containing all model-specific information for caching.
This allows the TeaCacheHook to remain completely generic: all model-specific logic is encapsulated in the extractor that returns this context.
Attributes:
| Name | Type | Description |
|---|---|---|
| modulated_input | Tensor | Tensor used for the cache decision (similarity comparison). Must be a torch.Tensor extracted from the first transformer block, typically after normalization and modulation have been applied. |
| hidden_states | Tensor | Current hidden states (modified by caching). Must be a torch.Tensor holding the main image/latent states after preprocessing but before the transformer blocks. |
| encoder_hidden_states | Tensor \| None | Optional encoder states for dual-stream models. Set to None for single-stream models (e.g., Flux); for dual-stream models (e.g., Qwen), holds the text encoder outputs. |
| temb | Tensor | Timestep embedding tensor. Must be a torch.Tensor containing the timestep conditioning. |
| run_transformer_blocks | Callable[[], tuple[Tensor, ...]] | Callable that executes the model-specific transformer blocks. Signature: () -> tuple[torch.Tensor, ...]. Returns a tuple whose element [0] is the processed hidden_states (required) and whose element [1] is the processed encoder_hidden_states (optional, dual-stream only). See the examples below the table. |
| postprocess | Callable[[Tensor], Any] | Callable that performs model-specific output postprocessing. Signature: (torch.Tensor) -> Union[torch.Tensor, Transformer2DModelOutput, tuple]. Takes the processed hidden_states and applies the final transformations (normalization, projection) to produce the model output. See the example below the table. |
| extra_states | dict[str, Any] \| None | Optional dict for additional model-specific state. Use it for models that need to pass context beyond the standard fields. |
Example run_transformer_blocks for a single-stream model:
def run_blocks():
    h = hidden_states
    for block in module.transformer_blocks:
        h = block(h, temb=temb)
    return (h,)
Example run_transformer_blocks for a dual-stream model:
def run_blocks():
    h, e = hidden_states, encoder_hidden_states
    for block in module.transformer_blocks:
        e, h = block(h, e, temb=temb)
    return (h, e)
Example postprocess:
def postprocess(h):
    h = module.norm_out(h, temb)
    output = module.proj_out(h)
    return Transformer2DModelOutput(sample=output)
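To illustrate how a generic hook can drive any model through this context alone, here is a minimal sketch in plain Python. It is not the TeaCacheHook implementation: SketchContext, SketchTeaCacheHook and rel_l1 are hypothetical names, lists of floats stand in for torch tensors, and the decision rule is TeaCache's accumulated relative-L1 test without its polynomial rescaling step. On a cache hit the hook skips run_transformer_blocks and reapplies the residual from the last computed step; either way the result goes through postprocess.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class SketchContext:
    """Hypothetical stand-in for CacheContext; lists of floats replace tensors."""

    modulated_input: list[float]
    hidden_states: list[float]
    encoder_hidden_states: list[float] | None
    temb: list[float]
    run_transformer_blocks: Callable[[], tuple[list[float], ...]]
    postprocess: Callable[[list[float]], Any]
    extra_states: dict[str, Any] | None = None


def rel_l1(a: list[float], b: list[float]) -> float:
    """Relative L1 distance sum|a - b| / sum|b|."""
    return sum(abs(x - y) for x, y in zip(a, b)) / sum(abs(y) for y in b)


class SketchTeaCacheHook:
    """Model-agnostic: touches only the fields and callables of the context."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold
        self.prev_modulated: list[float] | None = None
        self.cached_residual: list[float] | None = None
        self.accumulated = 0.0
        self.computed_steps = 0

    def step(self, ctx: SketchContext) -> Any:
        reuse = False  # the first step always computes
        if self.prev_modulated is not None and self.cached_residual is not None:
            # Accumulate the step-to-step change of the similarity signal.
            self.accumulated += rel_l1(ctx.modulated_input, self.prev_modulated)
            reuse = self.accumulated < self.threshold
        self.prev_modulated = ctx.modulated_input
        if reuse:
            # Cache hit: skip the blocks and reapply the last computed residual.
            out = [h + r for h, r in zip(ctx.hidden_states, self.cached_residual)]
        else:
            # Cache miss: run the blocks, store the residual, reset the accumulator.
            out = ctx.run_transformer_blocks()[0]
            self.cached_residual = [o - h for o, h in zip(out, ctx.hidden_states)]
            self.accumulated = 0.0
            self.computed_steps += 1
        return ctx.postprocess(out)
```

The hook only reads modulated_input and hidden_states and only calls the two callables, which is what lets a single hook serve every registered model.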
validate
Validate that the CacheContext contains valid data.
Raises:
| Type | Description |
|---|---|
TypeError | If fields have wrong types |
ValueError | If tensors have invalid properties |
RuntimeError | If callables fail basic invocation tests |
This method should be called after creating a CacheContext to catch common developer errors early with clear error messages.
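The three kinds of check that validate is documented to perform can be approximated as below. This is a hypothetical sketch, not the library's validate: MiniContext and validate_sketch are illustrative names, lists stand in for tensors, the tensor check is reduced to non-emptiness, and the RuntimeError case is modelled by confirming with inspect.signature(...).bind(...) that each callable accepts the documented call shape, without actually running it.

```python
from __future__ import annotations

import inspect
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class MiniContext:
    """Hypothetical stand-in for CacheContext; lists stand in for tensors."""

    modulated_input: Any
    hidden_states: Any
    temb: Any
    run_transformer_blocks: Callable[[], tuple]
    postprocess: Callable[[Any], Any]


def validate_sketch(ctx: MiniContext) -> None:
    # TypeError: fields of the wrong type.
    for name in ("modulated_input", "hidden_states", "temb"):
        value = getattr(ctx, name)
        if not isinstance(value, list):
            raise TypeError(f"{name} must be tensor-like, got {type(value).__name__}")
    for name in ("run_transformer_blocks", "postprocess"):
        if not callable(getattr(ctx, name)):
            raise TypeError(f"{name} must be callable")
    # ValueError: invalid tensor properties (here only emptiness is checked).
    if len(ctx.modulated_input) == 0:
        raise ValueError("modulated_input must not be empty")
    # RuntimeError: a callable cannot accept the documented call shape.
    try:
        inspect.signature(ctx.run_transformer_blocks).bind()
        inspect.signature(ctx.postprocess).bind(ctx.hidden_states)
    except TypeError as exc:
        raise RuntimeError(f"incompatible callable signature: {exc}") from exc
```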
extract_bagel_context
extract_bagel_context(
module: Module,
x_t: Tensor,
timestep: Tensor | float | int,
packed_vae_token_indexes: LongTensor,
packed_vae_position_ids: LongTensor,
packed_text_ids: LongTensor,
packed_text_indexes: LongTensor,
packed_position_ids: LongTensor,
packed_seqlens: IntTensor,
past_key_values: Any,
**kwargs: Any,
) -> CacheContext
Extract cache context for Bagel model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | Bagel instance | required |
x_t | Tensor | Latent image input | required |
timestep | Tensor \| float \| int | Current timestep | required
packed_vae_token_indexes | LongTensor | Indexes for VAE tokens in packed sequence | required |
packed_vae_position_ids | LongTensor | Position IDs for VAE tokens | required |
packed_text_ids | LongTensor | Text token IDs | required |
packed_text_indexes | LongTensor | Indexes for text tokens in packed sequence | required |
packed_position_ids | LongTensor | Global position IDs | required |
packed_seqlens | IntTensor | Sequence lengths | required |
past_key_values | Any | KV cache | required |
**kwargs | Any | Additional keyword arguments | {} |
Returns:
| Type | Description |
|---|---|
CacheContext | CacheContext with all information needed for generic caching |
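The index parameters describe where each token type sits in one packed sequence. The hypothetical helper below (strings stand in for token embeddings; pack_tokens is not part of the library) scatters text and VAE tokens into their slots to make that layout concrete. How the Bagel extractor actually consumes these indexes is not shown here.

```python
from __future__ import annotations


def pack_tokens(
    text_tokens: list[str],
    vae_tokens: list[str],
    packed_text_indexes: list[int],
    packed_vae_token_indexes: list[int],
) -> list[str]:
    """Scatter text and VAE tokens into one packed sequence by index (hypothetical)."""
    slots = list(packed_text_indexes) + list(packed_vae_token_indexes)
    # Every position 0..N-1 must be claimed by exactly one token.
    if sorted(slots) != list(range(len(slots))):
        raise ValueError("indexes must cover 0..N-1 exactly once")
    if len(text_tokens) != len(packed_text_indexes) or len(vae_tokens) != len(packed_vae_token_indexes):
        raise ValueError("each token needs exactly one index")
    packed = [""] * len(slots)
    for tok, idx in zip(text_tokens, packed_text_indexes):
        packed[idx] = tok
    for tok, idx in zip(vae_tokens, packed_vae_token_indexes):
        packed[idx] = tok
    return packed
```

For example, text indexes [0, 4] and VAE indexes [1, 2, 3] place the VAE tokens between two text tokens.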
extract_flux2_context
extract_flux2_context(
module: Module,
hidden_states: Tensor,
encoder_hidden_states: Tensor = None,
timestep: LongTensor = None,
img_ids: Tensor = None,
txt_ids: Tensor = None,
guidance: Tensor | None = None,
joint_attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
**kwargs: Any,
) -> CacheContext
Extract cache context for Flux2Transformer2DModel.
This is the ONLY Flux2-specific code needed for TeaCache support. It encapsulates preprocessing, modulated input extraction, transformer execution, and postprocessing logic.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | Flux2Transformer2DModel instance | required |
hidden_states | Tensor | Input hidden states tensor | required |
encoder_hidden_states | Tensor | Text encoder outputs | None |
timestep | LongTensor | Current diffusion timestep | None |
img_ids | Tensor | Image position IDs for position embedding (RoPE) | None
txt_ids | Tensor | Text position IDs for position embedding (RoPE) | None
guidance | Tensor \| None | Optional guidance scale for CFG | None
joint_attention_kwargs | dict[str, Any] \| None | Additional attention arguments | None
return_dict | bool | Whether to return a Transformer2DModelOutput instead of a plain tensor | True |
**kwargs | Any | Additional keyword arguments ignored by this extractor | {} |
Returns:
| Type | Description |
|---|---|
CacheContext | CacheContext with all information needed for generic caching |
extract_flux2_klein_context
extract_flux2_klein_context(
module: Module,
hidden_states: Tensor,
encoder_hidden_states: Tensor | None = None,
timestep: LongTensor = None,
img_ids: Tensor = None,
txt_ids: Tensor = None,
guidance: Tensor | None = None,
joint_attention_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> CacheContext
Extract cache context for Flux2Klein model.
Caches the full transformer output (including the single_transformer_blocks), so when the cache is reused, the single_transformer_blocks are skipped too, for maximum speedup.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | Flux2Transformer2DModel instance | required |
hidden_states | Tensor | Input image hidden states tensor | required |
encoder_hidden_states | Tensor \| None | Input text hidden states tensor | None
timestep | LongTensor | Current diffusion timestep | None |
img_ids | Tensor | Image position IDs for RoPE | None |
txt_ids | Tensor | Text position IDs for RoPE | None |
guidance | Tensor \| None | Optional guidance scale for CFG | None
joint_attention_kwargs | dict[str, Any] \| None | Additional attention kwargs | None
**kwargs | Any | Additional keyword arguments | {}
Returns:
| Type | Description |
|---|---|
CacheContext | CacheContext with all information needed for generic caching |
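Why caching the full output removes the single-stream cost can be seen in a sketch of a run_transformer_blocks closure that runs both block stacks. Assumptions, all hypothetical: single floats stand in for token tensors, the blocks are plain functions, and text tokens are concatenated before image tokens for the single-stream stage (the ordering used in diffusers' Flux models, assumed here for Flux2Klein). Because both stacks live inside one closure, a cache hit that skips the closure skips single_transformer_blocks as well.

```python
from __future__ import annotations

from typing import Callable

Tokens = list[float]  # hypothetical: one float per token stands in for a tensor


def make_run_blocks(
    hidden: Tokens,
    encoder: Tokens,
    dual_blocks: list[Callable[[Tokens, Tokens], tuple[Tokens, Tokens]]],
    single_blocks: list[Callable[[Tokens], Tokens]],
) -> Callable[[], tuple[Tokens]]:
    """Build a run_transformer_blocks closure covering BOTH block stacks."""

    def run_blocks() -> tuple[Tokens]:
        h, e = hidden, encoder
        for block in dual_blocks:  # dual-stream: text and image kept apart
            e, h = block(h, e)
        joint = e + h  # text tokens first, then image tokens
        for block in single_blocks:  # single-stream: one joint sequence
            joint = block(joint)
        return (joint[len(e):],)  # keep only the image tokens

    return run_blocks
```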
extract_flux_context
extract_flux_context(
module: Module,
hidden_states: Tensor,
encoder_hidden_states: Tensor,
pooled_projections: Tensor,
timestep: LongTensor,
img_ids: Tensor,
txt_ids: Tensor,
guidance: Tensor | None = None,
joint_attention_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> CacheContext
Extract cache context for FluxTransformer2DModel.
This mirrors the standard FLUX.1 transformer forward path while exposing the first modulated image stream tensor as TeaCache's similarity signal.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | FluxTransformer2DModel instance | required |
hidden_states | Tensor | Input image hidden states tensor | required |
encoder_hidden_states | Tensor | Text encoder outputs | required |
pooled_projections | Tensor | Pooled text conditioning | required |
timestep | LongTensor | Current diffusion timestep | required |
img_ids | Tensor | Image position IDs for RoPE | required |
txt_ids | Tensor | Text position IDs for RoPE | required |
guidance | Tensor \| None | Optional guidance scale for guidance-distilled variants | None
joint_attention_kwargs | dict[str, Any] \| None | Additional attention kwargs | None
**kwargs | Any | Additional keyword arguments | {} |
Returns:
| Type | Description |
|---|---|
CacheContext | CacheContext with all information needed for generic caching |
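For Flux-style blocks, the "modulated image stream tensor" is the image hidden state after the first block's adaptive LayerNorm. A minimal pure-Python sketch of that modulation, norm(x) * (1 + scale) + shift with a parameter-free LayerNorm (eps = 1e-6), follows. In diffusers' AdaLayerNormZero the shift and scale are projected from the timestep embedding; here they are passed in directly, and layer_norm and modulate are hypothetical helper names.

```python
from __future__ import annotations

import math


def layer_norm(x: list[float], eps: float = 1e-6) -> list[float]:
    """Parameter-free LayerNorm over one token's features."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def modulate(x: list[float], shift: list[float], scale: list[float]) -> list[float]:
    """AdaLN-style modulation: norm(x) * (1 + scale) + shift, per feature."""
    return [n * (1.0 + s) + b for n, s, b in zip(layer_norm(x), scale, shift)]
```

With zero shift and scale the result is just the normalized token, so any change in the modulated input between steps comes from the hidden states or from the timestep-dependent shift and scale.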
extract_longcat_context
extract_longcat_context(
module: Module,
hidden_states,
timestep,
guidance,
encoder_hidden_states,
txt_ids,
img_ids,
**kwargs,
) -> CacheContext
Extract the cache context for LongCat Image.
As with the other extractors, this is currently the only code needed for TeaCache support for LongCat Image; it encapsulates preprocessing, modulated input extraction, transformer execution, and postprocessing logic.
The positional and keyword arguments are identical to the inputs of LongCat Image's forward method.
Returns:
| Type | Description |
|---|---|
CacheContext | CacheContext with all information needed for generic caching |
extract_qwen_context
extract_qwen_context(
module: Module,
hidden_states: Tensor,
encoder_hidden_states: Tensor,
encoder_hidden_states_mask: Tensor,
timestep: Tensor | float | int,
img_shapes: Tensor,
txt_seq_lens: Tensor,
guidance: Tensor | None = None,
additional_t_cond: Tensor | None = None,
attention_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> CacheContext
Extract cache context for QwenImageTransformer2DModel.
This is the ONLY Qwen-specific code needed for TeaCache support. It encapsulates preprocessing, modulated input extraction, transformer execution, and postprocessing logic.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | QwenImageTransformer2DModel instance | required |
hidden_states | Tensor | Input hidden states tensor | required |
encoder_hidden_states | Tensor | Text encoder outputs | required |
encoder_hidden_states_mask | Tensor | Mask for text encoder | required |
timestep | Tensor \| float \| int | Current diffusion timestep | required
img_shapes | Tensor | Image shapes for position embedding | required |
txt_seq_lens | Tensor | Text sequence lengths | required |
guidance | Tensor \| None | Optional guidance scale for CFG | None
additional_t_cond | Tensor \| None | Optional additional timestep conditioning | None
attention_kwargs | dict[str, Any] \| None | Additional attention arguments | None
**kwargs | Any | Additional keyword arguments ignored by this extractor | {} |
Returns:
| Type | Description |
|---|---|
CacheContext | CacheContext with all information needed for generic caching |
extract_stable_audio_context
extract_stable_audio_context(
module: Module,
hidden_states: Tensor,
timestep: Tensor,
encoder_hidden_states: Tensor,
global_hidden_states: Tensor | None = None,
rotary_embedding: tuple[Tensor, Tensor] | None = None,
return_dict: bool = True,
attention_mask: Tensor | None = None,
encoder_attention_mask: Tensor | None = None,
**kwargs: Any,
) -> CacheContext
Extract cache context for Stable Audio DiT model.
Architecture Notes
- Stable Audio uses standard LayerNorm
- Timestep conditioning via global_hidden_states prepended to sequence
- Single-stream model (cross-attention handled separately)
- Input: [B, C, L] (C=in_channels) -> transpose -> [B, L, C] -> project -> [B, L, inner_dim]
- Global states prepended: [B, 1+L, inner_dim]
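The shape flow in the notes above can be traced with a small nested-list sketch. transpose_last2, project, prepend_global and shape are hypothetical helpers, the projection is a bias-free matrix multiply, and the prepended token is an arbitrary vector standing in for the global (timestep-conditioned) state.

```python
from __future__ import annotations

Tensor3 = list[list[list[float]]]  # hypothetical nested-list stand-in for a 3-D tensor


def transpose_last2(x: Tensor3) -> Tensor3:
    """[B, C, L] -> [B, L, C]."""
    return [[list(col) for col in zip(*batch)] for batch in x]


def project(x: Tensor3, weight: list[list[float]]) -> Tensor3:
    """[B, L, C] x [C, D] -> [B, L, D]: a bias-free linear projection."""
    d_out = len(weight[0])
    return [
        [[sum(tok[c] * weight[c][d] for c in range(len(tok))) for d in range(d_out)] for tok in batch]
        for batch in x
    ]


def prepend_global(x: Tensor3, global_states: list[list[float]]) -> Tensor3:
    """Prepend one global token per batch item: [B, L, D] -> [B, 1 + L, D]."""
    return [[g] + batch for g, batch in zip(global_states, x)]


def shape(x: Tensor3) -> tuple[int, int, int]:
    return (len(x), len(x[0]), len(x[0][0]))
```

With B = 1, C = 2, L = 3 and inner_dim = 4, the shapes go (1, 2, 3) -> (1, 3, 2) -> (1, 3, 4) -> (1, 4, 4).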
extract_zimage_context
extract_zimage_context(
module: Module,
x: list[Tensor],
t: Tensor,
cap_feats: list[Tensor],
patch_size: int = 2,
f_patch_size: int = 1,
**kwargs: Any,
) -> CacheContext
Extract cache context for ZImageTransformer2DModel.
This is the ONLY Z-Image-specific code needed for TeaCache support. It encapsulates preprocessing, modulated input extraction, transformer execution, and postprocessing logic.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module | Module | ZImageTransformer2DModel instance | required |
x | list[Tensor] | List of image tensors per batch item | required |
t | Tensor | Timestep tensor | required |
cap_feats | list[Tensor] | List of caption feature tensors per batch item | required |
patch_size | int | Patch size for patchification (default: 2) | 2 |
f_patch_size | int | Frame patch size (default: 1) | 1 |
**kwargs | Any | Additional keyword arguments ignored by this extractor | {} |
Returns:
| Type | Description |
|---|---|
CacheContext | CacheContext with all information needed for generic caching |
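Patchification determines how many tokens each item contributes. The sketch below assumes each entry of x has layout [C, F, H, W] and that a patch spans f_patch_size frames and patch_size by patch_size pixels; patch_grid is a hypothetical helper, not the extractor's code.

```python
from __future__ import annotations


def patch_grid(
    shape: tuple[int, int, int, int], patch_size: int = 2, f_patch_size: int = 1
) -> tuple[tuple[int, int, int], int, int]:
    """Return (token grid, token count, features per token) for one [C, F, H, W] item."""
    c, f, h, w = shape
    if f % f_patch_size or h % patch_size or w % patch_size:
        raise ValueError("frame/height/width must be divisible by the patch sizes")
    grid = (f // f_patch_size, h // patch_size, w // patch_size)
    tokens = grid[0] * grid[1] * grid[2]
    # Each token flattens one patch across all channels.
    features = c * f_patch_size * patch_size * patch_size
    return grid, tokens, features
```

Under these assumptions, and because x is a per-item list, items of different sizes yield different token counts: with the defaults, a [16, 1, 64, 64] latent gives 1024 tokens and a [16, 1, 32, 64] latent gives 512.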
get_extractor
Get extractor function for given transformer class.
This function looks up the extractor by the exact transformer_cls_name string, which should match the class name of the pipeline's transformer (i.e., `pipeline.transformer.__class__.__name__`).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
transformer_cls_name | str | Transformer class name (e.g., "QwenImageTransformer2DModel"). Must exactly match a key in EXTRACTOR_REGISTRY. | required
Returns:
| Type | Description |
|---|---|
Callable | Extractor function with signature `(module, *args, **kwargs) -> CacheContext`
Raises:
| Type | Description |
|---|---|
ValueError | If model type not found in registry |
Example
# Get the extractor for QwenImageTransformer2DModel
extractor = get_extractor("QwenImageTransformer2DModel")
ctx = extractor(transformer, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, timestep, ...)
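The exact-match behavior can be sketched with a plain dict. The registry, register, lookup and MyTransformer2DModel below are hypothetical stand-ins for EXTRACTOR_REGISTRY, register_extractor and get_extractor. The point is that the key is type(model).__name__, so in this sketch a subclass with a different name does not inherit its parent's extractor.

```python
from __future__ import annotations

from typing import Any, Callable

# Hypothetical registry mapping class names to extractor callables.
_REGISTRY: dict[str, Callable[..., Any]] = {}


def register(name: str, fn: Callable[..., Any]) -> None:
    _REGISTRY[name] = fn


def lookup(name: str) -> Callable[..., Any]:
    """Exact-match lookup; a miss raises ValueError listing the known names."""
    try:
        return _REGISTRY[name]
    except KeyError:
        known = ", ".join(sorted(_REGISTRY)) or "<none>"
        raise ValueError(f"No TeaCache extractor for {name!r}; registered: {known}") from None


class MyTransformer2DModel:  # hypothetical model class
    pass


register(MyTransformer2DModel.__name__, lambda module, *args, **kwargs: "context")
model = MyTransformer2DModel()
extractor = lookup(type(model).__name__)  # same string as model.__class__.__name__
```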
register_extractor
Register a new extractor function for a model type.
This allows extending TeaCache support to new models without modifying the core TeaCache code.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
transformer_cls_name | str | Transformer model type identifier (class name or type string) | required |
extractor_fn | Callable | Function with signature `(module, *args, **kwargs) -> CacheContext` | required
Example
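from vllm_omni.diffusion.cache.teacache.extractors import CacheContext, register_extractor

def extract_my_model_context(module, hidden_states, timestep, guidance=None, **kwargs):
    # Preprocessing
    temb = module.time_text_embed(timestep, guidance)
    # Extract the modulated input; norm1 is assumed to return a tuple whose first
    # element is the modulated tensor, as diffusers' AdaLayerNormZero does
    modulated = module.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
    # Define execution
    def run_blocks():
        h = hidden_states
        for block in module.transformer_blocks:
            h = block(h, temb=temb)
        return (h,)
    # Define postprocessing
    def postprocess(h):
        return module.proj_out(module.norm_out(h, temb))
    # Return the context (single-stream, so encoder_hidden_states is None)
    return CacheContext(modulated, hidden_states, None, temb, run_blocks, postprocess)

# "MyTransformer2DModel" is a hypothetical class name; "FluxTransformer2DModel" is
# already registered to the built-in extract_flux_context
register_extractor("MyTransformer2DModel", extract_my_model_context)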