Skip to content

vllm_omni.diffusion.models.flux2.flux2_transformer

Flux2Attention

Bases: Module

add_kv_num_heads instance-attribute

add_kv_num_heads = num_kv_heads

add_kv_proj instance-attribute

add_kv_proj = QKVParallelLinear(
    hidden_size=added_kv_proj_dim,
    head_size=head_dim,
    total_num_heads=heads,
    bias=added_proj_bias,
    quant_config=quant_config,
    prefix=_join_prefix(prefix, "add_kv_proj"),
)

add_query_num_heads instance-attribute

add_query_num_heads = num_heads

added_kv_proj_dim instance-attribute

added_kv_proj_dim = added_kv_proj_dim

attn instance-attribute

attn = Attention(
    num_heads=query_num_heads,
    head_size=head_dim,
    softmax_scale=1.0 / head_dim**0.5,
    causal=False,
    num_kv_heads=kv_num_heads,
)

dropout instance-attribute

dropout = dropout

head_dim instance-attribute

head_dim = dim_head

heads instance-attribute

heads = (
    out_dim // dim_head if out_dim is not None else heads
)

inner_dim instance-attribute

inner_dim = (
    out_dim if out_dim is not None else dim_head * heads
)

kv_num_heads instance-attribute

kv_num_heads = num_kv_heads

norm_added_k instance-attribute

norm_added_k = RMSNorm(dim_head, eps=eps)

norm_added_q instance-attribute

norm_added_q = RMSNorm(dim_head, eps=eps)

norm_k instance-attribute

norm_k = RMSNorm(dim_head, eps=eps)

norm_q instance-attribute

norm_q = RMSNorm(dim_head, eps=eps)

out_dim instance-attribute

out_dim = out_dim if out_dim is not None else query_dim

parallel_config instance-attribute

parallel_config = parallel_config

query_dim instance-attribute

query_dim = query_dim

query_num_heads instance-attribute

query_num_heads = num_heads

rope instance-attribute

rope = RotaryEmbedding(is_neox_style=False)

to_add_out instance-attribute

to_add_out = RowParallelLinear(
    inner_dim,
    query_dim,
    bias=out_bias,
    input_is_parallel=True,
    return_bias=False,
    quant_config=quant_config,
    prefix=_join_prefix(prefix, "to_add_out"),
)

to_out instance-attribute

to_out = ModuleList(
    [
        RowParallelLinear(
            inner_dim,
            out_dim,
            bias=out_bias,
            input_is_parallel=True,
            return_bias=False,
            quant_config=quant_config,
            prefix=_join_prefix(prefix, "to_out.0"),
        ),
        Dropout(dropout),
    ]
)

to_qkv instance-attribute

to_qkv = QKVParallelLinear(
    hidden_size=query_dim,
    head_size=head_dim,
    total_num_heads=heads,
    bias=bias,
    quant_config=quant_config,
    prefix=_join_prefix(prefix, "to_qkv"),
)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor | None = None,
    attention_mask: Tensor | None = None,
    image_rotary_emb: tuple[Tensor, Tensor] | None = None,
    **kwargs,
) -> Tensor | tuple[Tensor, Tensor]

Flux2FeedForward

Bases: Module

act_fn instance-attribute

act_fn = Flux2SwiGLU()

linear_in instance-attribute

linear_in = MergedColumnParallelLinear(
    dim,
    [inner_dim, inner_dim],
    bias=bias,
    return_bias=False,
    quant_config=quant_config,
    prefix=_join_prefix(prefix, "linear_in"),
)

linear_out instance-attribute

linear_out = RowParallelLinear(
    inner_dim,
    dim_out,
    bias=bias,
    input_is_parallel=True,
    return_bias=False,
    quant_config=quant_config,
    prefix=_join_prefix(prefix, "linear_out"),
)

forward

forward(x: Tensor) -> Tensor

Flux2Modulation

Bases: Module

act_fn instance-attribute

act_fn = SiLU()

linear instance-attribute

linear = Linear(dim, dim * 3 * mod_param_sets, bias=bias)

mod_param_sets instance-attribute

mod_param_sets = mod_param_sets

forward

forward(
    temb: Tensor,
) -> tuple[tuple[Tensor, Tensor, Tensor], ...]

Flux2ParallelSelfAttention

Bases: Module

Parallel attention block that fuses QKV projections with MLP input projections.

attn instance-attribute

attn = Attention(
    num_heads=heads,
    head_size=head_dim,
    softmax_scale=1.0 / head_dim**0.5,
    causal=False,
)

dropout instance-attribute

dropout = dropout

head_dim instance-attribute

head_dim = dim_head

heads instance-attribute

heads = (
    out_dim // dim_head if out_dim is not None else heads
)

inner_dim instance-attribute

inner_dim = (
    out_dim if out_dim is not None else dim_head * heads
)

mlp_act_fn instance-attribute

mlp_act_fn = Flux2SwiGLU()

mlp_hidden_dim instance-attribute

mlp_hidden_dim = int(query_dim * mlp_ratio)

mlp_mult_factor instance-attribute

mlp_mult_factor = mlp_mult_factor

mlp_ratio instance-attribute

mlp_ratio = mlp_ratio

norm_k instance-attribute

norm_k = RMSNorm(dim_head, eps=eps)

norm_q instance-attribute

norm_q = RMSNorm(dim_head, eps=eps)

out_dim instance-attribute

out_dim = out_dim if out_dim is not None else query_dim

parallel_config instance-attribute

parallel_config = parallel_config

query_dim instance-attribute

query_dim = query_dim

rope instance-attribute

rope = RotaryEmbedding(is_neox_style=False)

to_out instance-attribute

to_out = ColumnParallelLinear(
    inner_dim + mlp_hidden_dim,
    out_dim,
    bias=out_bias,
    gather_output=True,
    quant_config=quant_config,
    prefix=_join_prefix(prefix, "to_out"),
)

to_qkv_mlp_proj instance-attribute

to_qkv_mlp_proj = ColumnParallelLinear(
    query_dim,
    inner_dim * 3 + mlp_hidden_dim * mlp_mult_factor,
    bias=bias,
    gather_output=True,
    quant_config=quant_config,
    prefix=_join_prefix(prefix, "to_qkv_mlp_proj"),
)

forward

forward(
    hidden_states: Tensor,
    attention_mask: Tensor | None = None,
    image_rotary_emb: tuple[Tensor, Tensor] | None = None,
    **kwargs,
) -> Tensor

Flux2PosEmbed

Bases: Module

axes_dim instance-attribute

axes_dim = axes_dim

theta instance-attribute

theta = theta

forward

forward(ids: Tensor) -> tuple[Tensor, Tensor]

Flux2RopePrepare

Bases: Module

Prepares RoPE embeddings for sequence parallelism (SP).

This module encapsulates the RoPE computation for Flux.2-dev. For dual-stream attention, text components (outputs 0, 1) are replicated across SP ranks, while image components (outputs 2, 3) are sharded.

NOTE: The hidden_states projection is handled separately in forward() so that _sp_plan can shard it at the root level.

pos_embed instance-attribute

pos_embed = pos_embed

forward

forward(
    img_ids: Tensor, txt_ids: Tensor
) -> tuple[Tensor, Tensor, Tensor, Tensor]

Compute RoPE embeddings for text and image sequences.

Parameters:

Name Type Description Default
img_ids Tensor

Image position IDs (img_seq_len, n_axes)

required
txt_ids Tensor

Text position IDs (txt_seq_len, n_axes)

required

Returns:

Type Description
tuple[Tensor, Tensor, Tensor, Tensor]

Tuple of cosine / sine components for text and image, in the order: (txt_cos, txt_sin, img_cos, img_sin)

NOTE: Be careful about the output order if this is refactored in the future. It must match the _sp_plan indices, because the text components (0 & 1) need to be replicated across SP ranks, while the image components (2 & 3) must be sharded.

Flux2SingleTransformerBlock

Bases: Module

attn instance-attribute

attn = Flux2ParallelSelfAttention(
    parallel_config=parallel_config,
    query_dim=dim,
    dim_head=attention_head_dim,
    heads=num_attention_heads,
    out_dim=dim,
    bias=bias,
    out_bias=bias,
    eps=eps,
    mlp_ratio=mlp_ratio,
    mlp_mult_factor=2,
    quant_config=quant_config,
    prefix=_join_prefix(prefix, "attn"),
)

norm instance-attribute

norm = LayerNorm(dim, elementwise_affine=False, eps=eps)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor | None,
    temb_mod_params: tuple[Tensor, Tensor, Tensor],
    image_rotary_emb: tuple[Tensor, Tensor] | None = None,
    joint_attention_kwargs: dict[str, Any] | None = None,
    split_hidden_states: bool = False,
    text_seq_len: int | None = None,
) -> Tensor | tuple[Tensor, Tensor]

Forward pass for Flux2SingleTransformerBlock with SP support.

In SP mode, the image hidden_states is chunked to shape (B, img_len / SP, D), while the text encoder_hidden_states keeps its full shape (B, txt_len, D). The block concatenates them for joint attention.

Flux2SwiGLU

Bases: Module

SwiGLU activation used by Flux2.

gate_fn instance-attribute

gate_fn = SiLU()

forward

forward(x: Tensor) -> Tensor

Flux2TimestepGuidanceEmbeddings

Bases: Module

guidance_embedder instance-attribute

guidance_embedder = TimestepEmbedding(
    in_channels=in_channels,
    time_embed_dim=embedding_dim,
    sample_proj_bias=bias,
)

time_proj instance-attribute

time_proj = Timesteps(
    num_channels=in_channels,
    flip_sin_to_cos=True,
    downscale_freq_shift=0,
)

timestep_embedder instance-attribute

timestep_embedder = TimestepEmbedding(
    in_channels=in_channels,
    time_embed_dim=embedding_dim,
    sample_proj_bias=bias,
)

forward

forward(
    timestep: Tensor, guidance: Tensor | None
) -> Tensor

Flux2Transformer2DModel

Bases: Module

The Transformer model introduced in Flux 2.

Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig.

config instance-attribute

config = SimpleNamespace(
    patch_size=patch_size,
    in_channels=in_channels,
    out_channels=out_channels,
    num_layers=num_layers,
    num_single_layers=num_single_layers,
    attention_head_dim=attention_head_dim,
    num_attention_heads=num_attention_heads,
    joint_attention_dim=joint_attention_dim,
    timestep_guidance_channels=timestep_guidance_channels,
    mlp_ratio=mlp_ratio,
    axes_dims_rope=axes_dims_rope,
    rope_theta=rope_theta,
    eps=eps,
    guidance_embeds=guidance_embeds,
)

context_embedder instance-attribute

context_embedder = Linear(
    joint_attention_dim, inner_dim, bias=False
)

double_stream_modulation_img instance-attribute

double_stream_modulation_img = Flux2Modulation(
    inner_dim, mod_param_sets=2, bias=False
)

double_stream_modulation_txt instance-attribute

double_stream_modulation_txt = Flux2Modulation(
    inner_dim, mod_param_sets=2, bias=False
)

dtype property

dtype: dtype

guidance_embeds instance-attribute

guidance_embeds = guidance_embeds

inner_dim instance-attribute

inner_dim = num_attention_heads * attention_head_dim

norm_out instance-attribute

norm_out = AdaLayerNormContinuous(
    inner_dim,
    inner_dim,
    elementwise_affine=False,
    eps=eps,
    bias=False,
)

out_channels instance-attribute

out_channels = out_channels or in_channels

parallel_config instance-attribute

parallel_config = parallel_config

pos_embed instance-attribute

pos_embed = Flux2PosEmbed(
    theta=rope_theta, axes_dim=axes_dims_rope
)

proj_out instance-attribute

proj_out = Linear(
    inner_dim,
    patch_size * patch_size * out_channels,
    bias=False,
)

rope_prepare instance-attribute

rope_prepare = Flux2RopePrepare(pos_embed)

single_stream_modulation instance-attribute

single_stream_modulation = Flux2Modulation(
    inner_dim, mod_param_sets=1, bias=False
)

single_transformer_blocks instance-attribute

single_transformer_blocks = ModuleList(
    [
        (
            Flux2SingleTransformerBlock(
                parallel_config=parallel_config,
                dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                mlp_ratio=mlp_ratio,
                eps=eps,
                bias=False,
                quant_config=quant_config,
                prefix=f"single_transformer_blocks.{i}",
            )
        )
        for i in (range(num_single_layers))
    ]
)

stacked_params_mapping instance-attribute

stacked_params_mapping = None

time_guidance_embed instance-attribute

time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
    in_channels=timestep_guidance_channels,
    embedding_dim=inner_dim,
    bias=False,
    guidance_embeds=guidance_embeds,
)

transformer_blocks instance-attribute

transformer_blocks = ModuleList(
    [
        (
            Flux2TransformerBlock(
                parallel_config=parallel_config,
                dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                mlp_ratio=mlp_ratio,
                eps=eps,
                bias=False,
                quant_config=quant_config,
                prefix=f"transformer_blocks.{i}",
            )
        )
        for i in (range(num_layers))
    ]
)

x_embedder instance-attribute

x_embedder = Linear(in_channels, inner_dim, bias=False)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor = None,
    timestep: LongTensor = None,
    img_ids: Tensor = None,
    txt_ids: Tensor = None,
    guidance: Tensor | None = None,
    joint_attention_kwargs: dict[str, Any] | None = None,
    return_dict: bool = True,
) -> Tensor | Transformer2DModelOutput

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Flux2TransformerBlock

Bases: Module

attn instance-attribute

attn = Flux2Attention(
    parallel_config=parallel_config,
    query_dim=dim,
    added_kv_proj_dim=dim,
    dim_head=attention_head_dim,
    heads=num_attention_heads,
    out_dim=dim,
    bias=bias,
    added_proj_bias=bias,
    out_bias=bias,
    eps=eps,
    quant_config=quant_config,
    prefix=_join_prefix(prefix, "attn"),
)

ff instance-attribute

ff = Flux2FeedForward(
    dim=dim,
    dim_out=dim,
    mult=mlp_ratio,
    bias=bias,
    quant_config=quant_config,
    prefix=_join_prefix(prefix, "ff"),
)

ff_context instance-attribute

ff_context = Flux2FeedForward(
    dim=dim,
    dim_out=dim,
    mult=mlp_ratio,
    bias=bias,
    quant_config=quant_config,
    prefix=_join_prefix(prefix, "ff_context"),
)

mlp_hidden_dim instance-attribute

mlp_hidden_dim = int(dim * mlp_ratio)

norm1 instance-attribute

norm1 = LayerNorm(dim, elementwise_affine=False, eps=eps)

norm1_context instance-attribute

norm1_context = LayerNorm(
    dim, elementwise_affine=False, eps=eps
)

norm2 instance-attribute

norm2 = LayerNorm(dim, elementwise_affine=False, eps=eps)

norm2_context instance-attribute

norm2_context = LayerNorm(
    dim, elementwise_affine=False, eps=eps
)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor,
    temb_mod_params_img: tuple[
        tuple[Tensor, Tensor, Tensor], ...
    ],
    temb_mod_params_txt: tuple[
        tuple[Tensor, Tensor, Tensor], ...
    ],
    image_rotary_emb: tuple[Tensor, Tensor] | None = None,
    joint_attention_kwargs: dict[str, Any] | None = None,
) -> tuple[Tensor, Tensor]