Cache-DiT¶
This section describes how to add cache-dit acceleration to a new diffusion pipeline. We use the Qwen-Image pipeline and LongCat-Image pipeline as reference implementations.
Table of Contents¶
- Overview
- Standard Models: Automatic Support
- Custom Architectures: Writing Custom Implementation
- Testing
- Troubleshooting
- Reference Implementations
- Summary
Overview¶
What is Cache-DiT?¶
Cache-DiT is an acceleration library for Diffusion Transformers (DiT) that caches intermediate computation results across denoising steps. The core insight is that adjacent denoising steps often produce similar intermediate features, so we can skip redundant computations by reusing cached results.
The library supports three main caching strategies:
- DBCache: Dynamic block-level caching that selectively computes or caches transformer blocks based on residual differences
- TaylorSeer: Calibration-based prediction that estimates block outputs using Taylor expansion
- SCM (Step Computation Masking): Dynamic step skipping based on configurable policies
Architecture¶
vLLM-omni integrates cache-dit through the CacheDiTBackend class, which provides a unified interface for managing cache-dit acceleration on diffusion models.
| Method/Class | Purpose | Behavior |
|---|---|---|
CacheDiTBackend | Unified backend interface | Automatically handles enabler selection and cache refresh |
enable_cache_for_dit() | Apply caching to transformer | Configures DBCache on transformer blocks |
Key APIs from Cache-DiT:
| API | Description |
|---|---|
BlockAdapter | Core abstraction for applying cache-dit to transformers. Specifies transformer module(s), block list(s), and forward signature pattern(s). |
ForwardPattern | Defines block forward signature patterns: Pattern_0, Pattern_1, Pattern_2 |
ParamsModifier | Per-transformer or per-block-list cache configuration customization |
DBCacheConfig | Configuration for DBCache parameters (warmup steps, cached steps, thresholds) |
refresh_context() | Update cache context |
Standard Models: Automatic Support¶
Most DiT models follow this pattern: - Single transformer with one ModuleList of blocks - Standard forward signature - Compatible with cache-dit's automatic detection
Examples: Qwen-Image, Z-Image
For standard single-transformer models, no code changes are needed. The CacheDiTBackend automatically uses enable_cache_for_dit():
from vllm_omni import Omni
# Works automatically for standard models
omni = Omni(
model="Qwen/Qwen-Image", # Standard single-transformer model
cache_backend="cache_dit",
cache_config={
"Fn_compute_blocks": 1,
"Bn_compute_blocks": 0,
"max_warmup_steps": 4,
}
)
What happens automatically:
def enable_cache_for_dit(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
"""Default enabler for standard single-transformer DiT models."""
# Build cache configuration
db_cache_config = DBCacheConfig(
num_inference_steps=None, # Will be set during first inference
Fn_compute_blocks=cache_config.Fn_compute_blocks,
Bn_compute_blocks=cache_config.Bn_compute_blocks,
max_warmup_steps=cache_config.max_warmup_steps,
max_cached_steps=cache_config.max_cached_steps,
max_continuous_cached_steps=cache_config.max_continuous_cached_steps,
residual_diff_threshold=cache_config.residual_diff_threshold,
)
# Enable cache-dit on transformer
cache_dit.enable_cache(
pipeline.transformer,
cache_config=db_cache_config,
)
# Return refresh function for dynamic num_inference_steps updates
def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True):
cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
return refresh_cache_context
Custom Architectures: Writing Custom Implementation¶
Some models require custom handling:
- Single or dual-transformer: Models that may use one or two transformers (e.g., Wan2.2)
- Multi-block-list: Models with multiple block lists in one transformer (e.g., LongCatImage with
transformer_blocks+single_transformer_blocks) - Special forward patterns: Models with non-standard block execution patterns
Example 1: Single or Dual-Transformer Model (Wan2.2)¶
Wan2.2 can use either a single transformer or two transformers (one for high-noise steps and one for low-noise steps). The implementation automatically detects the mode based on the presence of transformer_2.
Key difference: Use BlockAdapter to wrap multiple transformers with separate configurations.
# Standard: cache_dit.enable_cache(pipeline.transformer, ...)
# Custom: Use BlockAdapter to handle multiple transformers
cache_dit.enable_cache(
BlockAdapter(
transformer=[pipeline.transformer, pipeline.transformer_2], # Multiple transformers
blocks=[pipeline.transformer.blocks, pipeline.transformer_2.blocks],
forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2],
params_modifiers=[
ParamsModifier(...), # Config for high-noise transformer
ParamsModifier(...), # Config for low-noise transformer (different params)
],
),
cache_config=db_cache_config,
)
Key difference: refresh_context must be called on each transformer separately.
# Standard: cache_dit.refresh_context(pipeline.transformer, num_inference_steps=N)
# Custom: Refresh each transformer with its own step count
def refresh_cache_context(pipeline, num_inference_steps, verbose=True):
high_steps, low_steps = _split_inference_steps(num_inference_steps)
cache_dit.refresh_context(pipeline.transformer, num_inference_steps=high_steps, ...)
cache_dit.refresh_context(pipeline.transformer_2, num_inference_steps=low_steps, ...)
Example 2: Multi-Block-List Model (LongCatImage)¶
LongCatImage has a single transformer with two block lists: transformer_blocks and single_transformer_blocks.
Key difference: Use BlockAdapter to specify multiple block lists within one transformer.
# Standard: cache_dit.enable_cache(pipeline.transformer, ...)
# - Automatically detects single block list
# Custom: Use BlockAdapter to specify multiple block lists
cache_dit.enable_cache(
BlockAdapter(
transformer=pipeline.transformer, # Single transformer
blocks=[
pipeline.transformer.transformer_blocks, # Block list 1
pipeline.transformer.single_transformer_blocks, # Block list 2
],
forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1],
params_modifiers=[modifier],
),
cache_config=db_cache_config,
)
Note: For single transformer with multiple block lists,
refresh_contextworks the same as standard models.
Registering Custom Implementations¶
After writing your custom enabler, register it in CUSTOM_DIT_ENABLERS in vllm_omni/diffusion/cache/cache_dit_backend.py:
CUSTOM_DIT_ENABLERS = {
"Wan22Pipeline": enable_cache_for_wan22,
"LongCatImagePipeline": enable_cache_for_longcat_image,
"YourCustomPipeline": enable_cache_for_your_model, # Add here
}
Testing¶
After adding cache-dit support, test with:
from vllm_omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
# Test your custom model
omni = Omni(
model="your-model-name",
cache_backend="cache_dit",
cache_config={
"Fn_compute_blocks": 1,
"Bn_compute_blocks": 0,
"max_warmup_steps": 4,
"residual_diff_threshold": 0.24,
}
)
images = omni.generate(
"a beautiful landscape",
OmniDiffusionSamplingParams(num_inference_steps=50),
)
Verify:
- Cache is applied (check logs for "Cache-dit enabled successfully on xxx")
- Performance improvement (should be around 1.5x-2x faster)
- Image quality (compare with
cache_backend=None)
Troubleshooting¶
Issue: Cache not applied¶
Symptoms: No speedup observed, no cache-related log messages.
Causes & Solutions:
- Enabler not registered:
Problem: Pipeline name not in CUSTOM_DIT_ENABLERS registry.
Solution: Verify pipeline.__class__.__name__ matches the registry key and add your enabler to CUSTOM_DIT_ENABLERS.
Issue: Quality degradation¶
Symptoms: Generated images have artifacts or lower quality compared to non-cached inference.
Causes & Solutions:
- Cache parameters too aggressive:
Solution:
cache_config={
"residual_diff_threshold": 0.12, # Lower from 0.24 (try 0.12-0.18)
"max_warmup_steps": 6, # Increase from 4 (try 6-8)
"max_continuous_cached_steps": 2, # Reduce if higher
}
Check the user guide for cache_dit for more adjustable parameters.
Reference Implementations¶
Complete examples in the codebase:
| Model | Path | Pattern | Notes |
|---|---|---|---|
| Standard DiT | cache_dit_backend.py::enable_cache_for_dit | Default enabler | Single transformer, automatic |
| Wan2.2 | cache_dit_backend.py::enable_cache_for_wan22 | Single or dual-transformer | Auto-detects mode based on transformer_2 presence |
| LongCat | cache_dit_backend.py::enable_cache_for_longcat_image | Multi-block-list | Two block lists in one transformer |
| BAGEL | cache_dit_backend.py::enable_cache_for_bagel | Omni model | Complex architecture |
Summary¶
Adding cache-dit support:
- ✅ Check model type - Standard models work automatically, custom architectures need enablers
- ✅ Write enabler (if needed) - Use
BlockAdapterfor complex architectures - ✅ Register enabler (if needed) - Add to
CUSTOM_DIT_ENABLERSdictionary - ✅ Return refresh function (if needed) - Handle
num_inference_stepschanges - ✅ Test - Verify with
cache_backend="cache_dit"
For most models, the default enabler is sufficient. Only write custom enablers for complex architectures!