vllm_omni.diffusion.models.flux2.flux2_transformer ¶
Flux2Attention ¶
Bases: Module
add_kv_proj instance-attribute ¶
add_kv_proj = QKVParallelLinear(
hidden_size=added_kv_proj_dim,
head_size=head_dim,
total_num_heads=heads,
bias=added_proj_bias,
quant_config=quant_config,
prefix=_join_prefix(prefix, "add_kv_proj"),
)
attn instance-attribute ¶
attn = Attention(
num_heads=query_num_heads,
head_size=head_dim,
softmax_scale=1.0 / head_dim**0.5,
causal=False,
num_kv_heads=kv_num_heads,
)
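The `softmax_scale` and `causal=False` arguments above correspond to ordinary bidirectional scaled dot-product attention. The following is a minimal single-head sketch in plain Python, not the `Attention` layer itself; the helper names `softmax` and `attention` are hypothetical.

```python
import math


def softmax(xs):
    # Subtract the max for numerical stability before exponentiating.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v, head_dim):
    """Non-causal scaled dot-product attention for one head.

    q, k, v are lists of token vectors, each of length head_dim.
    """
    scale = 1.0 / head_dim**0.5  # same formula as softmax_scale above
    out = []
    for qi in q:
        # causal=False: every query sees every key, so no mask is applied.
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        weights = softmax(scores)
        out.append(
            [sum(w * vj[d] for w, vj in zip(weights, v)) for d in range(len(v[0]))]
        )
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
result = attention(q, k, v, head_dim=2)
```

Each output row is a convex combination of the rows of `v`, weighted toward the key that best matches the query.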
to_add_out instance-attribute ¶
to_add_out = RowParallelLinear(
inner_dim,
query_dim,
bias=out_bias,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
prefix=_join_prefix(prefix, "to_add_out"),
)
to_out instance-attribute ¶
to_out = ModuleList(
[
RowParallelLinear(
inner_dim,
out_dim,
bias=out_bias,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
prefix=_join_prefix(prefix, "to_out.0"),
),
Dropout(dropout),
]
)
to_qkv instance-attribute ¶
to_qkv = QKVParallelLinear(
hidden_size=query_dim,
head_size=head_dim,
total_num_heads=heads,
bias=bias,
quant_config=quant_config,
prefix=_join_prefix(prefix, "to_qkv"),
)
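Both `to_qkv` and `add_kv_proj` are fused QKV projections that are split by attention head across tensor-parallel ranks. As a rough sketch of the per-rank output widths, assuming the head counts divide evenly by the tensor-parallel size (the helper `qkv_shard_sizes` is hypothetical and ignores the KV-head replication a real layer may perform):

```python
def qkv_shard_sizes(total_num_heads, head_size, tp_size, total_num_kv_heads=None):
    """Return the (q, k, v) output widths held by one tensor-parallel rank."""
    if total_num_kv_heads is None:
        # Assumption: without an explicit KV head count, K and V use as many heads as Q.
        total_num_kv_heads = total_num_heads
    if total_num_heads % tp_size or total_num_kv_heads % tp_size:
        raise ValueError("head counts must be divisible by tp_size in this sketch")
    heads = total_num_heads // tp_size
    kv_heads = total_num_kv_heads // tp_size
    return (heads * head_size, kv_heads * head_size, kv_heads * head_size)


# Illustrative numbers only: 48 heads of size 128 over 4 ranks.
sizes = qkv_shard_sizes(total_num_heads=48, head_size=128, tp_size=4)
```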
Flux2FeedForward ¶
Bases: Module
linear_in instance-attribute ¶
linear_in = MergedColumnParallelLinear(
dim,
[inner_dim, inner_dim],
bias=bias,
return_bias=False,
quant_config=quant_config,
prefix=_join_prefix(prefix, "linear_in"),
)
linear_out instance-attribute ¶
linear_out = RowParallelLinear(
inner_dim,
dim_out,
bias=bias,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
prefix=_join_prefix(prefix, "linear_out"),
)
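`linear_in` produces two `inner_dim`-wide halves in one matmul, which is the usual layout for a SwiGLU feed-forward. A minimal sketch of the gating step between `linear_in` and `linear_out`, assuming the first half is the gate and SiLU is the gate activation (both are assumptions here, and `silu` / `swiglu` are hypothetical helpers):

```python
import math


def silu(x):
    # SiLU(x) = x * sigmoid(x)
    return x / (1.0 + math.exp(-x))


def swiglu(fused):
    """Split a fused [gate | value] vector in half and gate the value half."""
    half = len(fused) // 2
    gate, value = fused[:half], fused[half:]
    return [silu(g) * u for g, u in zip(gate, value)]


# A fused vector with inner_dim = 2: gate = [0, 1], value = [5, 7].
out = swiglu([0.0, 1.0, 5.0, 7.0])
```

The output has width `inner_dim`, which matches the input width of `linear_out`.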
Flux2Modulation ¶
Flux2ParallelSelfAttention ¶
Bases: Module
Parallel attention block that fuses QKV projections with MLP input projections.
attn instance-attribute ¶
attn = Attention(
num_heads=heads,
head_size=head_dim,
softmax_scale=1.0 / head_dim**0.5,
causal=False,
)
to_out instance-attribute ¶
to_out = ColumnParallelLinear(
inner_dim + mlp_hidden_dim,
out_dim,
bias=out_bias,
gather_output=True,
quant_config=quant_config,
prefix=_join_prefix(prefix, "to_out"),
)
to_qkv_mlp_proj instance-attribute ¶
to_qkv_mlp_proj = ColumnParallelLinear(
query_dim,
inner_dim * 3 + mlp_hidden_dim * mlp_mult_factor,
bias=bias,
gather_output=True,
quant_config=quant_config,
prefix=_join_prefix(prefix, "to_qkv_mlp_proj"),
)
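The output width of `to_qkv_mlp_proj` is `inner_dim * 3 + mlp_hidden_dim * mlp_mult_factor`, so a single matmul feeds both the attention and MLP branches. The sketch below shows how one such fused row could be split back apart; the Q, K, V, MLP ordering is an assumption, and `split_fused_projection` is a hypothetical helper.

```python
def split_fused_projection(row, inner_dim, mlp_hidden_dim, mlp_mult_factor=2):
    """Split one fused output row into (q, k, v, mlp_in) slices."""
    expected = inner_dim * 3 + mlp_hidden_dim * mlp_mult_factor
    if len(row) != expected:
        raise ValueError(f"expected width {expected}, got {len(row)}")
    q = row[:inner_dim]
    k = row[inner_dim : 2 * inner_dim]
    v = row[2 * inner_dim : 3 * inner_dim]
    mlp_in = row[3 * inner_dim :]  # width mlp_hidden_dim * mlp_mult_factor
    return q, k, v, mlp_in


# Toy sizes: inner_dim = 4, mlp_hidden_dim = 6, so the row is 12 + 12 wide.
row = list(range(24))
q, k, v, mlp_in = split_fused_projection(row, inner_dim=4, mlp_hidden_dim=6)
```

With `mlp_mult_factor=2`, the MLP slice holds both halves needed by a gated activation such as SwiGLU.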
Flux2PosEmbed ¶
Flux2RopePrepare ¶
Bases: Module
Prepares RoPE embeddings for sequence parallelism (SP).
This module encapsulates the RoPE computation for Flux.2-dev. For dual-stream attention, text components (outputs 0, 1) are replicated across SP ranks, while image components (outputs 2, 3) are sharded.
NOTE: The hidden_states projection is handled separately in forward() so that _sp_plan can shard it at the root level.
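To illustrate what "RoPE computation" over multi-axis position IDs involves, here is a minimal sketch of the standard rotary formulation, where each axis gets its own slice of the head dimension. The frequency layout (each frequency repeated twice), the `theta` value, and the helper name `rope_cos_sin` are assumptions, not taken from this module.

```python
import math


def rope_cos_sin(ids, axes_dim, theta):
    """Compute per-token cos/sin rows for multi-axis position IDs.

    ids: list of position tuples, one per token, with len(axes_dim) entries.
    """
    cos_rows, sin_rows = [], []
    for pos in ids:
        cos_row, sin_row = [], []
        for p, dim in zip(pos, axes_dim):
            for j in range(dim // 2):
                angle = p * theta ** (-2.0 * j / dim)
                # Assumption: each frequency is repeated for its pair of channels.
                cos_row += [math.cos(angle)] * 2
                sin_row += [math.sin(angle)] * 2
        cos_rows.append(cos_row)
        sin_rows.append(sin_row)
    return cos_rows, sin_rows


# Two tokens with two position axes, four channels per axis (toy sizes).
cos_rows, sin_rows = rope_cos_sin([(0, 0), (1, 2)], axes_dim=(4, 4), theta=2000.0)
```

Each row has `sum(axes_dim)` entries, and a token at position zero on every axis gets the identity rotation.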
forward ¶
forward(
img_ids: Tensor, txt_ids: Tensor
) -> tuple[Tensor, Tensor, Tensor, Tensor]
Compute RoPE embeddings for text and image sequences.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| img_ids | Tensor | Image position IDs (img_seq_len, n_axes) | required |
| txt_ids | Tensor | Text position IDs (txt_seq_len, n_axes) | required |
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, Tensor, Tensor, Tensor] | Cosine and sine components for text and image, in the order (txt_cos, txt_sin, img_cos, img_sin) |
NOTE: Be careful to preserve the output order if this is refactored in the future. It must match the _sp_plan indices, because the text components (0 and 1) need to be replicated across SP ranks, while the image components (2 and 3) must be sharded.
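The index-based replicate/shard rule can be pictured as follows. This is a sketch only: it assumes contiguous, equal-sized chunks along the sequence dimension and an image length divisible by the SP size, and `distribute_rope_outputs` is a hypothetical helper rather than part of `_sp_plan`.

```python
def distribute_rope_outputs(outputs, sp_rank, sp_size):
    """Replicate outputs 0 and 1 (text); shard outputs 2 and 3 (image)."""
    txt_cos, txt_sin, img_cos, img_sin = outputs

    def shard(seq):
        if len(seq) % sp_size:
            raise ValueError("sequence length must be divisible by sp_size here")
        chunk = len(seq) // sp_size
        return seq[sp_rank * chunk : (sp_rank + 1) * chunk]

    # Text rows are returned whole on every rank; image rows are sliced per rank.
    return txt_cos, txt_sin, shard(img_cos), shard(img_sin)


# Toy data: 2 text positions and 4 image positions, labelled by name.
outputs = (["tc0", "tc1"], ["ts0", "ts1"],
           ["ic0", "ic1", "ic2", "ic3"], ["is0", "is1", "is2", "is3"])
rank0 = distribute_rope_outputs(outputs, sp_rank=0, sp_size=2)
rank1 = distribute_rope_outputs(outputs, sp_rank=1, sp_size=2)
```

Reordering the tuple would silently shard the text components instead, which is why the order must stay fixed.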
Flux2SingleTransformerBlock ¶
Bases: Module
attn instance-attribute ¶
attn = Flux2ParallelSelfAttention(
parallel_config=parallel_config,
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=bias,
out_bias=bias,
eps=eps,
mlp_ratio=mlp_ratio,
mlp_mult_factor=2,
quant_config=quant_config,
prefix=_join_prefix(prefix, "attn"),
)
forward ¶
forward(
hidden_states: Tensor,
encoder_hidden_states: Tensor | None,
temb_mod_params: tuple[Tensor, Tensor, Tensor],
image_rotary_emb: tuple[Tensor, Tensor] | None = None,
joint_attention_kwargs: dict[str, Any] | None = None,
split_hidden_states: bool = False,
text_seq_len: int | None = None,
) -> Tensor | tuple[Tensor, Tensor]
Forward pass for Flux2SingleTransformerBlock with SP support.
In SP mode, the image hidden_states are chunked across ranks with shape (B, img_len / SP, D), while the text encoder_hidden_states are kept whole with shape (B, txt_len, D). The block concatenates them for joint attention.
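A small sketch of the sequence bookkeeping described here, using lists in place of tensors. It assumes text tokens are placed before image tokens in the joint sequence and that `text_seq_len` marks the split point when `split_hidden_states` is set; the helper names are hypothetical.

```python
def local_joint_len(txt_len, img_len, sp_size):
    # Each rank holds the full text sequence plus its own image chunk.
    return txt_len + img_len // sp_size


def joint_sequence(text_tokens, image_chunk):
    # Assumption: text first, then the local image chunk.
    return text_tokens + image_chunk


def split_joint(joint, text_seq_len):
    # Undo the concatenation: (text part, image part).
    return joint[:text_seq_len], joint[text_seq_len:]


text = ["t0", "t1"]
image_chunk = ["i0", "i1", "i2"]
joint = joint_sequence(text, image_chunk)
text_out, image_out = split_joint(joint, text_seq_len=len(text))
```

For example, with 512 text tokens, 4096 image tokens, and an SP size of 4, each rank attends over a local joint sequence of 1536 tokens.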
Flux2SwiGLU ¶
Flux2TimestepGuidanceEmbeddings ¶
Bases: Module
guidance_embedder instance-attribute ¶
guidance_embedder = TimestepEmbedding(
in_channels=in_channels,
time_embed_dim=embedding_dim,
sample_proj_bias=bias,
)
time_proj instance-attribute ¶
timestep_embedder instance-attribute ¶
timestep_embedder = TimestepEmbedding(
in_channels=in_channels,
time_embed_dim=embedding_dim,
sample_proj_bias=bias,
)
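Before `timestep_embedder` and `guidance_embedder` are applied, the scalar timestep and guidance values are typically expanded into sinusoidal features of width `in_channels`. The sketch below shows one common form of that projection; the cosine-first ordering, the `max_period` default, and the helper name `sinusoidal_embedding` are assumptions and may differ from what `time_proj` actually does.

```python
import math


def sinusoidal_embedding(t, num_channels, max_period=10000.0):
    """Map a scalar t to num_channels sinusoidal features."""
    half = num_channels // 2
    # Geometrically spaced frequencies from 1 down towards 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]


emb_zero = sinusoidal_embedding(0, 8)
emb_one = sinusoidal_embedding(1.0, 8)
```

The two embedders then map these features to `embedding_dim`, so both signals live in the same space.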
Flux2Transformer2DModel ¶
Bases: Module
The Transformer model introduced in Flux 2.
Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig.
config instance-attribute ¶
config = SimpleNamespace(
patch_size=patch_size,
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
num_single_layers=num_single_layers,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
joint_attention_dim=joint_attention_dim,
timestep_guidance_channels=timestep_guidance_channels,
mlp_ratio=mlp_ratio,
axes_dims_rope=axes_dims_rope,
rope_theta=rope_theta,
eps=eps,
guidance_embeds=guidance_embeds,
)
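Because `config` is a plain `SimpleNamespace`, its fields can be read with attribute access. The sketch below uses illustrative values (assumptions, not read from any checkpoint) to show two derived quantities: the model width, assumed here to be `num_attention_heads * attention_head_dim`, and the total RoPE width, assumed to equal the head dimension.

```python
from types import SimpleNamespace

# Illustrative values only; a real checkpoint may use different numbers.
config = SimpleNamespace(
    attention_head_dim=128,
    num_attention_heads=48,
    axes_dims_rope=(32, 32, 32, 32),
    num_layers=8,
    num_single_layers=48,
)

# Assumption: the model width is heads times head size.
inner_dim = config.num_attention_heads * config.attention_head_dim

# Assumption: the per-axis RoPE widths together cover one attention head.
rope_dim = sum(config.axes_dims_rope)
```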
context_embedder instance-attribute ¶
double_stream_modulation_img instance-attribute ¶
double_stream_modulation_img = Flux2Modulation(
inner_dim, mod_param_sets=2, bias=False
)
double_stream_modulation_txt instance-attribute ¶
double_stream_modulation_txt = Flux2Modulation(
inner_dim, mod_param_sets=2, bias=False
)
norm_out instance-attribute ¶
norm_out = AdaLayerNormContinuous(
inner_dim,
inner_dim,
elementwise_affine=False,
eps=eps,
bias=False,
)
pos_embed instance-attribute ¶
pos_embed = Flux2PosEmbed(
theta=rope_theta, axes_dim=axes_dims_rope
)
proj_out instance-attribute ¶
single_stream_modulation instance-attribute ¶
single_stream_modulation = Flux2Modulation(
inner_dim, mod_param_sets=1, bias=False
)
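The `mod_param_sets` argument controls how many (shift, scale, gate) triplets a modulation layer yields: two for the double-stream blocks (attention and feed-forward) and one for the single-stream blocks. A minimal sketch of the chunking, assuming the projected vector has width `dim * 3 * mod_param_sets` and is cut into equal slices (`chunk_modulation` is a hypothetical helper):

```python
def chunk_modulation(mod, dim, mod_param_sets):
    """Split a modulation vector into mod_param_sets (shift, scale, gate) triplets."""
    expected = dim * 3 * mod_param_sets
    if len(mod) != expected:
        raise ValueError(f"expected width {expected}, got {len(mod)}")
    chunks = [mod[i * dim : (i + 1) * dim] for i in range(3 * mod_param_sets)]
    # Group consecutive slices in threes: (shift, scale, gate).
    return tuple(tuple(chunks[3 * i : 3 * i + 3]) for i in range(mod_param_sets))


double = chunk_modulation(list(range(12)), dim=2, mod_param_sets=2)
single = chunk_modulation(list(range(6)), dim=2, mod_param_sets=1)
```

This matches the shape of the `tuple[tuple[Tensor, Tensor, Tensor], ...]` arguments accepted by the block `forward` methods.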
single_transformer_blocks instance-attribute ¶
single_transformer_blocks = ModuleList(
[
(
Flux2SingleTransformerBlock(
parallel_config=parallel_config,
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
mlp_ratio=mlp_ratio,
eps=eps,
bias=False,
quant_config=quant_config,
prefix=f"single_transformer_blocks.{i}",
)
)
for i in (range(num_single_layers))
]
)
time_guidance_embed instance-attribute ¶
time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
in_channels=timestep_guidance_channels,
embedding_dim=inner_dim,
bias=False,
guidance_embeds=guidance_embeds,
)
transformer_blocks instance-attribute ¶
transformer_blocks = ModuleList(
[
(
Flux2TransformerBlock(
parallel_config=parallel_config,
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
mlp_ratio=mlp_ratio,
eps=eps,
bias=False,
quant_config=quant_config,
prefix=f"transformer_blocks.{i}",
)
)
for i in (range(num_layers))
]
)
forward ¶
forward(
hidden_states: Tensor,
encoder_hidden_states: Tensor = None,
timestep: LongTensor = None,
img_ids: Tensor = None,
txt_ids: Tensor = None,
guidance: Tensor | None = None,
joint_attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> Tensor | Transformer2DModelOutput
Flux2TransformerBlock ¶
Bases: Module
attn instance-attribute ¶
attn = Flux2Attention(
parallel_config=parallel_config,
query_dim=dim,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=bias,
added_proj_bias=bias,
out_bias=bias,
eps=eps,
quant_config=quant_config,
prefix=_join_prefix(prefix, "attn"),
)
ff instance-attribute ¶
ff = Flux2FeedForward(
dim=dim,
dim_out=dim,
mult=mlp_ratio,
bias=bias,
quant_config=quant_config,
prefix=_join_prefix(prefix, "ff"),
)
ff_context instance-attribute ¶
ff_context = Flux2FeedForward(
dim=dim,
dim_out=dim,
mult=mlp_ratio,
bias=bias,
quant_config=quant_config,
prefix=_join_prefix(prefix, "ff_context"),
)
norm1_context instance-attribute ¶
norm2_context instance-attribute ¶
forward ¶
forward(
hidden_states: Tensor,
encoder_hidden_states: Tensor,
temb_mod_params_img: tuple[
tuple[Tensor, Tensor, Tensor], ...
],
temb_mod_params_txt: tuple[
tuple[Tensor, Tensor, Tensor], ...
],
image_rotary_emb: tuple[Tensor, Tensor] | None = None,
joint_attention_kwargs: dict[str, Any] | None = None,
) -> tuple[Tensor, Tensor]
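Each triplet in `temb_mod_params_img` and `temb_mod_params_txt` is typically consumed in an AdaLN-style pattern: shift and scale modulate the normalized input, and the gate scales the branch output before the residual add. The sketch below illustrates that pattern on plain lists; the exact ordering inside `Flux2TransformerBlock` is an assumption, and `modulate` / `gated_residual` are hypothetical helpers.

```python
def modulate(x_norm, shift, scale):
    # Scale around 1 so that zero modulation leaves the input unchanged.
    return [(1.0 + s) * x + b for x, s, b in zip(x_norm, scale, shift)]


def gated_residual(x, gate, branch_out):
    # The gate decides how much of the branch output is added back.
    return [xi + g * bi for xi, g, bi in zip(x, gate, branch_out)]


shift, scale, gate = [0.5, 0.5], [1.0, 0.0], [0.0, 2.0]
modulated = modulate([1.0, 2.0], shift, scale)
updated = gated_residual([1.0, 1.0], gate, [3.0, 3.0])
```

A gate of zero leaves the residual stream untouched, so conditioning can switch a branch off per channel.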