Skip to content

vllm_omni.diffusion.models.flux.flux_transformer

logger module-attribute

logger = init_logger(__name__)

ColumnParallelApproxGELU

Bases: Module

approximate instance-attribute

approximate = approximate

proj instance-attribute

proj = ColumnParallelLinear(
    dim_in,
    dim_out,
    bias=bias,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.proj",
)

forward

forward(x: Tensor) -> Tensor

FeedForward

Bases: Module

net instance-attribute

net = ModuleList(layers)

forward

forward(hidden_states: Tensor) -> Tensor

FluxAttention

Bases: Module

add_kv_proj instance-attribute

add_kv_proj = QKVParallelLinear(
    hidden_size=added_kv_proj_dim,
    head_size=head_dim,
    total_num_heads=heads,
    bias=added_proj_bias,
    quant_config=quant_config,
    prefix=f"{prefix}.add_kv_proj",
)

added_kv_proj_dim instance-attribute

added_kv_proj_dim = added_kv_proj_dim

added_proj_bias instance-attribute

added_proj_bias = added_proj_bias

attn instance-attribute

attn = Attention(
    num_heads=num_heads,
    head_size=head_dim,
    softmax_scale=1.0 / head_dim**0.5,
    causal=False,
    num_kv_heads=num_kv_heads,
)

context_pre_only instance-attribute

context_pre_only = context_pre_only

dropout instance-attribute

dropout = dropout

head_dim instance-attribute

head_dim = dim_head

heads instance-attribute

heads = (
    out_dim // dim_head if out_dim is not None else heads
)

inner_dim instance-attribute

inner_dim = (
    out_dim if out_dim is not None else dim_head * heads
)

norm_added_k instance-attribute

norm_added_k = RMSNorm(dim_head, eps=eps)

norm_added_q instance-attribute

norm_added_q = RMSNorm(dim_head, eps=eps)

norm_k instance-attribute

norm_k = RMSNorm(dim_head, eps=eps)

norm_q instance-attribute

norm_q = RMSNorm(dim_head, eps=eps)

out_dim instance-attribute

out_dim = out_dim if out_dim is not None else query_dim

pre_only instance-attribute

pre_only = pre_only

query_dim instance-attribute

query_dim = query_dim

rope instance-attribute

rope = RotaryEmbedding(is_neox_style=False)

to_add_out instance-attribute

to_add_out = RowParallelLinear(
    inner_dim,
    query_dim,
    bias=out_bias,
    input_is_parallel=True,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.to_add_out",
)

to_out instance-attribute

to_out = ModuleList(
    [
        RowParallelLinear(
            inner_dim,
            out_dim,
            bias=out_bias,
            input_is_parallel=True,
            return_bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.to_out.0",
        ),
        Dropout(dropout),
    ]
)

to_qkv instance-attribute

to_qkv = QKVParallelLinear(
    hidden_size=query_dim,
    head_size=head_dim,
    total_num_heads=heads,
    bias=bias,
    quant_config=quant_config,
    prefix=f"{prefix}.to_qkv",
)

use_bias instance-attribute

use_bias = bias

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor | None = None,
    image_rotary_emb: tuple[Tensor, Tensor] | None = None,
    attention_mask: Tensor | None = None,
    **kwargs,
) -> Tensor | tuple[Tensor, Tensor]

FluxKontextTransformer2DModel

Bases: FluxTransformer2DModel

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor = None,
    pooled_projections: Tensor = None,
    timestep: LongTensor = None,
    img_ids: Tensor = None,
    txt_ids: Tensor = None,
    guidance: Tensor | None = None,
    joint_attention_kwargs: dict[str, Any] | None = None,
    return_dict: bool = True,
) -> Tensor | Transformer2DModelOutput

FluxPosEmbed

Bases: Module

axes_dim instance-attribute

axes_dim = axes_dim

theta instance-attribute

theta = theta

forward

forward(ids: Tensor) -> Tensor

FluxSingleTransformerBlock

Bases: Module

act_mlp instance-attribute

act_mlp = GELU(approximate='tanh')

attn instance-attribute

attn = FluxAttention(
    query_dim=dim,
    dim_head=attention_head_dim,
    heads=num_attention_heads,
    out_dim=dim,
    bias=True,
    eps=1e-06,
    pre_only=True,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)

mlp_hidden_dim instance-attribute

mlp_hidden_dim = int(dim * mlp_ratio)

norm instance-attribute

norm = AdaLayerNormZeroSingle(
    dim,
    quant_config=_safe_quant_config(quant_config),
    prefix=f"{prefix}.norm",
)

proj_mlp instance-attribute

proj_mlp = ReplicatedLinear(
    dim,
    mlp_hidden_dim,
    bias=True,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.proj_mlp",
)

proj_out instance-attribute

proj_out = ReplicatedLinear(
    dim + mlp_hidden_dim,
    dim,
    bias=True,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.proj_out",
)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor,
    temb: Tensor,
    image_rotary_emb: tuple[Tensor, Tensor] | None = None,
    joint_attention_kwargs: dict[str, Any] | None = None,
) -> tuple[Tensor, Tensor]

FluxTransformer2DModel

Bases: Module

The Transformer model introduced in Flux.

Parameters:

Name Type Description Default
od_config `OmniDiffusionConfig`

The configuration for the model.

None
patch_size `int`, defaults to `1`

Patch size to turn the input data into small patches.

1
in_channels `int`, defaults to `64`

The number of channels in the input.

64
out_channels `int`, *optional*, defaults to `None`

The number of channels in the output. If not specified, it defaults to in_channels.

None
num_layers `int`, defaults to `19`

The number of layers of dual stream DiT blocks to use.

19
num_single_layers `int`, defaults to `38`

The number of layers of single stream DiT blocks to use.

38
attention_head_dim `int`, defaults to `128`

The number of dimensions to use for each attention head.

128
num_attention_heads `int`, defaults to `24`

The number of attention heads to use.

24
joint_attention_dim `int`, defaults to `4096`

The number of dimensions to use for the joint attention (embedding/channel dimension of encoder_hidden_states).

4096
pooled_projection_dim `int`, defaults to `768`

The number of dimensions to use for the pooled projection.

768
guidance_embeds `bool`, defaults to `True`

Whether to use guidance embeddings for guidance-distilled variant of the model.

True
axes_dims_rope `Tuple[int]`, defaults to `(16, 56, 56)`

The dimensions to use for the rotary positional embeddings.

(16, 56, 56)

context_embedder instance-attribute

context_embedder = Linear(joint_attention_dim, inner_dim)

guidance_embeds instance-attribute

guidance_embeds = guidance_embeds

in_channels instance-attribute

in_channels = in_channels

inner_dim instance-attribute

inner_dim = num_attention_heads * attention_head_dim

norm_out instance-attribute

norm_out = AdaLayerNormContinuous(
    inner_dim,
    inner_dim,
    elementwise_affine=False,
    eps=1e-06,
    quant_config=_safe_quant_config(quant_config),
    prefix="norm_out",
)

out_channels instance-attribute

out_channels = out_channels or in_channels

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "to_qkv": ["to_q", "to_k", "to_v"],
    "add_kv_proj": [
        "add_q_proj",
        "add_k_proj",
        "add_v_proj",
    ],
}

parallel_config instance-attribute

parallel_config = parallel_config

pos_embed instance-attribute

pos_embed = FluxPosEmbed(
    theta=theta, axes_dim=axes_dims_rope
)

proj_out instance-attribute

proj_out = Linear(
    inner_dim,
    patch_size * patch_size * out_channels,
    bias=True,
)

single_transformer_blocks instance-attribute

single_transformer_blocks = ModuleList(
    [
        (
            FluxSingleTransformerBlock(
                dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                quant_config=quant_config,
                prefix=f"single_transformer_blocks.{i}",
            )
        )
        for i in (range(num_single_layers))
    ]
)

time_text_embed instance-attribute

time_text_embed = text_time_guidance_cls(
    embedding_dim=inner_dim,
    pooled_projection_dim=pooled_projection_dim,
)

transformer_blocks instance-attribute

transformer_blocks = ModuleList(
    [
        (
            FluxTransformerBlock(
                dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                quant_config=_safe_quant_config(
                    quant_config
                ),
                prefix=f"transformer_blocks.{i}",
            )
        )
        for i in (range(num_layers))
    ]
)

x_embedder instance-attribute

x_embedder = Linear(in_channels, inner_dim)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor = None,
    pooled_projections: Tensor = None,
    timestep: LongTensor = None,
    img_ids: Tensor = None,
    txt_ids: Tensor = None,
    guidance: Tensor | None = None,
    joint_attention_kwargs: dict[str, Any] | None = None,
    return_dict: bool = True,
) -> Tensor | Transformer2DModelOutput

The `FluxTransformer2DModel` forward method.

Parameters:

Name Type Description Default
hidden_states `torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`

Input hidden_states.

required
encoder_hidden_states `torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`

Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.

None
pooled_projections `torch.Tensor` of shape `(batch_size, projection_dim)`

Embeddings projected from the embeddings of input conditions.

None
timestep `torch.LongTensor`

Used to indicate denoising step.

None
img_ids `torch.Tensor`

The position ids for image tokens.

None
txt_ids `torch.Tensor`

The position ids for text tokens.

None
guidance `torch.Tensor`

Guidance embeddings for guidance-distilled variant of the model.

None
joint_attention_kwargs `dict`, *optional*

A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under `self.processor` in `diffusers.models.attention_processor`.

None
return_dict `bool`, *optional*, defaults to `True`

Whether or not to return a `Transformer2DModelOutput` instead of a plain tuple.

True

Returns:

Type Description
Tensor | Transformer2DModelOutput

If return_dict is True, a `Transformer2DModelOutput` is returned, otherwise a tuple where the first element is the sample tensor.

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

FluxTransformerBlock

Bases: Module

attn instance-attribute

attn = FluxAttention(
    query_dim=dim,
    added_kv_proj_dim=dim,
    dim_head=attention_head_dim,
    heads=num_attention_heads,
    out_dim=dim,
    context_pre_only=False,
    bias=True,
    eps=eps,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)

ff instance-attribute

ff = FeedForward(
    dim=dim,
    dim_out=dim,
    quant_config=quant_config,
    prefix=f"{prefix}.ff",
)

ff_context instance-attribute

ff_context = FeedForward(
    dim=dim,
    dim_out=dim,
    quant_config=quant_config,
    prefix=f"{prefix}.ff_context",
)

norm1 instance-attribute

norm1 = AdaLayerNormZero(
    dim, quant_config=quant_config, prefix=f"{prefix}.norm1"
)

norm1_context instance-attribute

norm1_context = AdaLayerNormZero(
    dim,
    quant_config=quant_config,
    prefix=f"{prefix}.norm1_context",
)

norm2 instance-attribute

norm2 = LayerNorm(dim, elementwise_affine=False, eps=1e-06)

norm2_context instance-attribute

norm2_context = LayerNorm(
    dim, elementwise_affine=False, eps=1e-06
)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor,
    temb: Tensor,
    image_rotary_emb: tuple[Tensor, Tensor] | None = None,
    joint_attention_kwargs: dict[str, Any] | None = None,
) -> tuple[Tensor, Tensor]