
vllm_omni.diffusion.cache.magcache.strategy

MagCache strategy definitions for different model architectures.

This module provides model-specific strategies for MagCache, allowing easy extension to new models by implementing the MagCacheStrategy interface.

Architecture:
- MagCacheStrategy: Abstract base class defining the strategy interface
- FluxMagCacheStrategy: Strategy for Flux (dual-stream) models
- Flux2MagCacheStrategy: Strategy for Flux2 models (extends FluxMagCacheStrategy)

Flux2MagCacheStrategy

Bases: FluxMagCacheStrategy

MagCache strategy for Flux2 model.

Flux2 shares the same dual-stream architecture as Flux, but may have different tensor shapes in some transformer blocks, requiring special handling in residual computation.

FLUX2_MAG_RATIOS class-attribute instance-attribute

FLUX2_MAG_RATIOS = tensor(
    [
        1.0,
        0.96528,
        1.11559,
        1.0565,
        1.00425,
        1.0805,
        0.98616,
        1.09289,
        1.03196,
        1.06679,
        1.03941,
        1.05375,
        1.03128,
        1.05349,
        1.01983,
        1.05535,
        1.0662,
        1.05748,
        1.00318,
        1.05222,
        1.04556,
        1.0506,
        1.05058,
        1.05219,
        1.02025,
        1.05052,
        1.04143,
        1.0498,
    ]
)

mag_ratios property

mag_ratios: Tensor

Return default mag_ratios for Flux2 model.

apply_residual

apply_residual(
    hidden_states: Tensor, residual: Tensor
) -> Tensor

Apply residual for Flux2.

For single-stream blocks: if shapes match, adds residual; otherwise returns input. For dual-stream blocks: applies decoder residual only.

apply_residual_tuple

apply_residual_tuple(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor,
    residual: tuple[Tensor, Tensor],
) -> tuple[Tensor, Tensor]

Apply residual tuple for Flux2 with shape mismatch handling.

For Flux2 dual-stream blocks, handles the case where hidden_states and decoder_residual may have different shapes.

compute_residual

compute_residual(
    output: Tensor, head_input: Tensor
) -> Tensor | tuple[Tensor, Tensor]

Compute residual for Flux2 with dual-stream support.

For dual-stream blocks, computes residual for both encoder and decoder branches. For single-stream blocks: if shapes match, computes output - input; otherwise returns output.
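The single-stream shape rule described above for Flux2's apply_residual and compute_residual can be sketched as two standalone functions. This is a minimal sketch, not the module's code: the function names are hypothetical, and the dual-stream branches are left out because their exact inputs are not documented on this page.

```python
import torch
from torch import Tensor


def compute_single_stream_residual(output: Tensor, head_input: Tensor) -> Tensor:
    """Hypothetical helper: residual rule for a single-stream block."""
    if output.shape == head_input.shape:
        # Matching shapes: the residual is what the blocks added to the input.
        return output - head_input
    # Mismatched shapes: return the block output itself.
    return output


def apply_single_stream_residual(hidden_states: Tensor, residual: Tensor) -> Tensor:
    """Hypothetical helper: residual application for a single-stream block."""
    if hidden_states.shape == residual.shape:
        # Matching shapes: reuse the cached residual instead of running the blocks.
        return hidden_states + residual
    # Mismatched shapes: leave the input untouched.
    return hidden_states


# Round trip: on the same input, the cached residual reproduces the block output.
x = torch.arange(12.0).reshape(3, 4)
y = x * 2 + 1
r = compute_single_stream_residual(y, x)
assert torch.equal(apply_single_stream_residual(x, r), y)
```

The shape guard makes both operations no-ops rather than errors when a block changes the tensor shape, so caching degrades gracefully instead of failing on broadcasting.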

register_block_metadata

register_block_metadata(
    block_class: type,
) -> TransformerBlockMetadata | None

Register Flux2-specific block metadata based on block type.

Flux2 has two block types with different output formats:
- Dual-stream (Flux2TransformerBlock): returns (encoder_hidden_states, hidden_states)
- Single-stream (Flux2SingleTransformerBlock): returns a single tensor

FluxMagCacheStrategy

Bases: MagCacheStrategy

MagCache strategy for Flux (dual-stream) models.

Flux architecture:
- Transformer blocks (dual-stream): image tokens and text tokens are processed independently with separate weights
- Single transformer blocks (single-stream): the concatenated sequence (image + text tokens) shares one set of weights
- Final norm_out and proj_out layers

This strategy provides:
- mag_ratios: Pre-computed magnitude ratios for Flux (28 steps)
- compute_residual: Handles the tuple output format
- apply_residual_tuple: Applies the decoder residual only

FLUX_MAG_RATIOS class-attribute instance-attribute

FLUX_MAG_RATIOS = tensor(
    [
        1.0,
        1.07313,
        1.21035,
        1.04432,
        1.06818,
        1.05547,
        1.0183,
        1.03405,
        1.02574,
        1.03042,
        1.02739,
        1.01955,
        1.01585,
        1.02439,
        1.01154,
        1.01377,
        1.00994,
        1.01444,
        1.00839,
        1.02269,
        1.0007,
        1.00714,
        1.00484,
        1.01381,
        1.00426,
        0.99764,
        1.00778,
        1.00233,
    ]
)

mag_ratios property

mag_ratios: Tensor

Return default mag_ratios for Flux model.

apply_residual_tuple

apply_residual_tuple(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor,
    residual: tuple[Tensor, Tensor],
) -> tuple[Tensor, Tensor]

Apply residual tuple for Flux: only the decoder residual is added.

Flux has separate image and text processing, so the residual is only applied to the decoder (image) branch.
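The decoder-only rule can be sketched as follows. This is a minimal sketch, assuming the residual tuple uses the base-class order (hidden_states_residual, encoder_hidden_states_residual); the function name flux_apply_residual_tuple is hypothetical.

```python
from __future__ import annotations

import torch
from torch import Tensor


def flux_apply_residual_tuple(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor,
    residual: tuple[Tensor, Tensor],
) -> tuple[Tensor, Tensor]:
    # Assumption: residual follows the base-class ordering
    # (hidden_states_residual, encoder_hidden_states_residual).
    hidden_states_residual, _encoder_residual = residual
    # Only the decoder (image) branch receives the cached residual;
    # the encoder (text) branch is passed through unchanged.
    return hidden_states + hidden_states_residual, encoder_hidden_states
```

Contrast this with the base-class default, which adds both residuals separately.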

compute_residual

compute_residual(
    output: Tensor, head_input: Tensor
) -> Tensor

Compute residual for Flux output format (tuple or single tensor).

Flux single transformer blocks return a tuple, so we extract the decoder output (index 1) before computing residual.
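The tuple unpacking described above can be sketched like this. It is a minimal sketch, assuming a tuple output is ordered (encoder_hidden_states, hidden_states) as in the Flux2 block description; flux_compute_residual is a hypothetical name.

```python
from __future__ import annotations

import torch
from torch import Tensor


def flux_compute_residual(
    output: Tensor | tuple[Tensor, Tensor], head_input: Tensor
) -> Tensor:
    # Blocks that return a tuple are assumed to use the order
    # (encoder_hidden_states, hidden_states); keep the decoder output at index 1.
    if isinstance(output, tuple):
        output = output[1]
    # Residual between the decoder output and the input to the first block.
    return output - head_input
```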

nearest_interp staticmethod

nearest_interp(
    src_array: Tensor, target_length: int
) -> Tensor

Interpolate mag_ratios to target length using nearest neighbor.
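Nearest-neighbor resampling lets a ratio table calibrated for one step count (28 for the Flux defaults) be reused at another. A minimal sketch: target positions are spread evenly over the source range and each is rounded to the nearest source index. Returning the last ratio when target_length == 1 is an assumption, not documented behavior.

```python
import torch
from torch import Tensor


def nearest_interp(src_array: Tensor, target_length: int) -> Tensor:
    src_length = src_array.shape[0]
    if target_length == 1:
        # Assumed degenerate case: keep only the last ratio.
        return src_array[-1:].clone()
    # Map target position i onto the source index range [0, src_length - 1].
    scale = (src_length - 1) / (target_length - 1)
    indices = torch.round(torch.arange(target_length) * scale).long()
    # Pick the nearest source entry for every target position.
    return src_array[indices]
```

For example, resampling a 28-entry table to 50 steps keeps the first and last ratios and repeats interior entries as needed.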

MagCacheStrategy

Bases: ABC

Abstract base class for MagCache strategies.

Each model architecture requires a specific strategy to handle:
- Residual computation (how to calculate the residual for caching)
- Residual application (how to apply a cached residual)
- Model-specific magnitude ratios

Implement this class to add support for new model architectures.

mag_ratios abstractmethod property

mag_ratios: Tensor

Return the default mag_ratios tensor for this model.

This tensor defines the expected residual magnitude ratio for each denoising step (the Flux and Flux2 defaults have 28 entries, matching 28 steps). Values should be calibrated for the specific model architecture.

Returns:

  Tensor: 1D tensor of mag_ratios (one per denoising step).

apply_residual

apply_residual(
    hidden_states: Tensor, residual: Tensor
) -> Tensor

Apply cached residual to hidden states.

Default implementation: add residual to hidden_states. This works for most model architectures.

Parameters:

  hidden_states (Tensor, required): Current hidden states.
  residual (Tensor, required): Cached residual to apply.

Returns:

  Tensor: Hidden states with residual added.

apply_residual_tuple

apply_residual_tuple(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor,
    residual: tuple[Tensor, Tensor],
) -> tuple[Tensor, Tensor]

Apply cached residual tuple to both hidden_states and encoder_hidden_states.

Default implementation: add residuals separately. Override this method for models with specific residual application logic.

Parameters:

  hidden_states (Tensor, required): Current hidden states.
  encoder_hidden_states (Tensor, required): Current encoder hidden states.
  residual (tuple[Tensor, Tensor], required): Tuple of (hidden_states_residual, encoder_hidden_states_residual).

Returns:

  tuple[Tensor, Tensor]: Tuple of (hidden_states, encoder_hidden_states) with residuals applied.

compute_calibration_metrics

compute_calibration_metrics(
    current_residual: Tensor | tuple[Tensor, Tensor],
    previous_residual: Tensor
    | tuple[Tensor, Tensor]
    | None,
) -> tuple[float, float, float]

Compute calibration metrics for mag_ratios generation.

Default implementation computes norm ratios and cosine dissimilarity. Override this method for models with custom metric computation.

Parameters:

  current_residual (Tensor | tuple[Tensor, Tensor], required): Residual from the current step.
  previous_residual (Tensor | tuple[Tensor, Tensor] | None, required): Residual from the previous step (None for the first step).

Returns:

  tuple[float, float, float]: Tuple of (norm_ratio, norm_std, cos_dis):
    - norm_ratio: Mean ratio of current to previous residual norms
    - norm_std: Standard deviation of the norm ratios
    - cos_dis: Mean cosine dissimilarity (1 - cosine_similarity)
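The three metrics can be sketched as below. This is a minimal sketch under several assumptions not stated on this page: statistics are taken per token along the last (hidden) dimension, tuple residuals contribute only their first element, and the first step (no previous residual) reports the neutral values (1.0, 0.0, 0.0), which fits the leading 1.0 in both default ratio tables.

```python
from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import Tensor


def compute_calibration_metrics(
    current_residual: Tensor | tuple[Tensor, Tensor],
    previous_residual: Tensor | tuple[Tensor, Tensor] | None,
) -> tuple[float, float, float]:
    # Assumption: for tuple residuals, use the hidden_states residual (index 0).
    if isinstance(current_residual, tuple):
        current_residual = current_residual[0]
    if isinstance(previous_residual, tuple):
        previous_residual = previous_residual[0]
    # Assumption: the first step has no predecessor, so report neutral values.
    if previous_residual is None:
        return 1.0, 0.0, 0.0
    # Per-token L2 norms along the hidden dimension, then their ratio.
    ratios = torch.linalg.vector_norm(current_residual, dim=-1) / torch.linalg.vector_norm(
        previous_residual, dim=-1
    )
    # Per-token cosine similarity between consecutive residuals.
    cos_sim = F.cosine_similarity(current_residual, previous_residual, dim=-1, eps=1e-8)
    norm_ratio = ratios.mean().item()
    norm_std = ratios.std().item()  # NaN if only one token is present
    cos_dis = (1.0 - cos_sim).mean().item()
    return norm_ratio, norm_std, cos_dis
```

A residual that simply doubles between steps yields a norm_ratio near 2.0 with near-zero spread and dissimilarity.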

compute_residual

compute_residual(
    output: Tensor, head_input: Tensor
) -> Tensor

Compute residual between block output and input.

Default implementation: output - head_input. Override this method for models with non-standard output formats.

Parameters:

  output (Tensor, required): Output from transformer blocks.
  head_input (Tensor, required): Input to the first block.

Returns:

  Tensor: Residual tensor for caching.

get_calibration_metrics_names

get_calibration_metrics_names() -> tuple[str, str, str]

Return the names of calibration metrics for logging.

Returns:

  tuple[str, str, str]: Tuple of metric names in order: (norm_ratio_name, norm_std_name, cos_dis_name).

register_block_metadata

register_block_metadata(
    block_class: type,
) -> TransformerBlockMetadata | None

Register model-specific transformer block metadata.

Override this method to provide custom metadata for transformer blocks that have non-standard output formats (e.g., tuple returns).

Parameters:

  block_class (type, required): The transformer block class to register.

Returns:

  TransformerBlockMetadata | None: TransformerBlockMetadata if custom registration is needed, None otherwise.

MagCacheStrategyRegistry

Registry for MagCache strategies by transformer type.

get classmethod

get(transformer_type: str) -> MagCacheStrategy

Get strategy for given transformer type.

get_if_exists classmethod

get_if_exists(
    transformer_type: str,
) -> MagCacheStrategy | None

Get the strategy if one is registered, otherwise return None.

register classmethod

register(name: str, strategy: MagCacheStrategy) -> None

Register a strategy with an explicit name.

Parameters:

  name (str, required): Transformer model type identifier (e.g., "FluxTransformer2DModel").
  strategy (MagCacheStrategy, required): MagCacheStrategy instance.
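A registry with these three class methods can be sketched as a class-level dictionary keyed by transformer type. This is a minimal sketch, not the module's implementation: the stand-in MagCacheStrategy keeps only the abstract property, and having get raise ValueError is an assumption borrowed from get_strategy's documented behavior.

```python
from __future__ import annotations

from abc import ABC, abstractmethod


class MagCacheStrategy(ABC):
    # Stand-in for the real base class; only the abstract property is kept.
    @property
    @abstractmethod
    def mag_ratios(self): ...


class MagCacheStrategyRegistry:
    # Shared mapping: transformer type name -> strategy instance.
    _strategies: dict[str, MagCacheStrategy] = {}

    @classmethod
    def register(cls, name: str, strategy: MagCacheStrategy) -> None:
        cls._strategies[name] = strategy

    @classmethod
    def get_if_exists(cls, transformer_type: str) -> MagCacheStrategy | None:
        # Exact-string lookup; missing keys yield None.
        return cls._strategies.get(transformer_type)

    @classmethod
    def get(cls, transformer_type: str) -> MagCacheStrategy:
        strategy = cls.get_if_exists(transformer_type)
        if strategy is None:
            # Assumption: mirror get_strategy, which documents a ValueError.
            raise ValueError(f"No MagCache strategy registered for {transformer_type!r}")
        return strategy
```

Storing instances rather than classes means one strategy object is reused for every pipeline with the same transformer type.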

get_strategy

get_strategy(transformer_cls_name: str) -> MagCacheStrategy

Get the strategy for the given transformer class.

This function looks up the strategy by the exact transformer_cls_name string, which should match the class name of the pipeline's transformer (i.e., pipeline.transformer.__class__.__name__).

Parameters:

  transformer_cls_name (str, required): Transformer class name (e.g., "FluxTransformer2DModel"). Must exactly match a registered strategy.

Returns:

  MagCacheStrategy: MagCacheStrategy instance for the model.

Raises:

  ValueError: If the model type is not found in the registry.
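Because the lookup key is the exact runtime class name, inheritance does not carry over. The sketch below illustrates this with hypothetical stand-in classes and a plain dictionary in place of the real registry; lookup_strategy is a hypothetical name.

```python
class FluxTransformer2DModel:
    """Hypothetical stand-in for the real transformer class."""


class MyFluxVariant(FluxTransformer2DModel):
    """Hypothetical subclass with a different class name."""


# Hypothetical registry contents: exact class-name strings as keys.
_STRATEGIES = {"FluxTransformer2DModel": "flux-strategy"}


def lookup_strategy(transformer: object) -> str:
    # Key on the runtime class name, as with pipeline.transformer.__class__.__name__.
    name = transformer.__class__.__name__
    if name not in _STRATEGIES:
        raise ValueError(f"No MagCache strategy registered for {name!r}")
    return _STRATEGIES[name]
```

Under this exact-match rule, a subclass such as MyFluxVariant raises ValueError even though its parent is registered, so it would need its own registration.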

register_strategy

register_strategy(
    transformer_cls_name: str, strategy: MagCacheStrategy
) -> None

Register a MagCache strategy for a model type.

This allows extending MagCache support to new models without modifying the core MagCache code.

Parameters:

  transformer_cls_name (str, required): Transformer model type identifier (class name or type string). Must match pipeline.transformer.__class__.__name__.
  strategy (MagCacheStrategy, required): MagCacheStrategy instance for this model type.

Example:

>>> class MyModelMagCacheStrategy(MagCacheStrategy):
...     @property
...     def mag_ratios(self):
...         return torch.tensor([...])
>>> register_strategy("MyModelTransformer", MyModelMagCacheStrategy())