vllm_omni.diffusion.models.flux.flux_transformer ¶
ColumnParallelApproxGELU ¶
FeedForward ¶
FluxAttention ¶
Bases: Module
add_kv_proj instance-attribute ¶
add_kv_proj = QKVParallelLinear(
    hidden_size=added_kv_proj_dim,
    head_size=head_dim,
    total_num_heads=heads,
    bias=added_proj_bias,
    quant_config=quant_config,
    prefix=f"{prefix}.add_kv_proj",
)
attn instance-attribute ¶
attn = Attention(
    num_heads=num_heads,
    head_size=head_dim,
    softmax_scale=1.0 / head_dim**0.5,
    causal=False,
    num_kv_heads=num_kv_heads,
)
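The `softmax_scale=1.0 / head_dim**0.5` and `causal=False` arguments above correspond to ordinary bidirectional scaled dot-product attention. The following is a minimal pure-Python sketch for a single head, given only for illustration: `sdpa_single_head` is a hypothetical helper that is not part of `vllm_omni`, and the real `Attention` layer works on batched tensors.

```python
import math


def sdpa_single_head(q, k, v):
    """Non-causal scaled dot-product attention for one head.

    q is a list of query vectors; k and v are lists of key and value
    vectors. All vectors are plain lists of floats (hypothetical layout).
    """
    head_dim = len(q[0])
    scale = 1.0 / head_dim ** 0.5  # same formula as softmax_scale above
    out = []
    for qi in q:
        # causal=False: every query attends to every key, no masking.
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        out.append(
            [sum(w * vj[d] for w, vj in zip(weights, v)) for d in range(len(v[0]))]
        )
    return out
```

With a zero query, all scores are equal, so the output is the plain average of the value vectors.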
to_add_out instance-attribute ¶
to_add_out = RowParallelLinear(
    inner_dim,
    query_dim,
    bias=out_bias,
    input_is_parallel=True,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.to_add_out",
)
to_out instance-attribute ¶
to_out = ModuleList(
    [
        RowParallelLinear(
            inner_dim,
            out_dim,
            bias=out_bias,
            input_is_parallel=True,
            return_bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.to_out.0",
        ),
        Dropout(dropout),
    ]
)
to_qkv instance-attribute ¶
to_qkv = QKVParallelLinear(
    hidden_size=query_dim,
    head_size=head_dim,
    total_num_heads=heads,
    bias=bias,
    quant_config=quant_config,
    prefix=f"{prefix}.to_qkv",
)
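To make the role of `head_size` and `total_num_heads` concrete, the sketch below estimates the per-rank output width of a fused Q/K/V projection under tensor parallelism. It is a simplified illustration, not the library's implementation: `fused_qkv_width` and `tp_size` are hypothetical names, and it assumes the number of key/value heads equals the number of query heads and that the heads divide evenly across ranks.

```python
def fused_qkv_width(total_num_heads: int, head_size: int, tp_size: int = 1) -> int:
    """Per-rank output width of a fused Q/K/V projection.

    Assumes (for illustration) no grouped-query attention, i.e. Q, K and V
    each use total_num_heads heads, sharded evenly over tp_size ranks.
    """
    if total_num_heads % tp_size != 0:
        raise ValueError("total_num_heads must be divisible by tp_size")
    heads_per_rank = total_num_heads // tp_size
    # Three projections (Q, K, V), each heads_per_rank * head_size wide.
    return 3 * heads_per_rank * head_size
```

For the default configuration documented on this page (24 heads of size 128), this gives 9216 output features on a single rank and 4608 per rank with `tp_size=2`.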
FluxKontextTransformer2DModel ¶
Bases: FluxTransformer2DModel
forward ¶
forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor = None,
    pooled_projections: Tensor = None,
    timestep: LongTensor = None,
    img_ids: Tensor = None,
    txt_ids: Tensor = None,
    guidance: Tensor | None = None,
    joint_attention_kwargs: dict[str, Any] | None = None,
    return_dict: bool = True,
) -> Tensor | Transformer2DModelOutput
FluxPosEmbed ¶
FluxSingleTransformerBlock ¶
Bases: Module
attn instance-attribute ¶
attn = FluxAttention(
    query_dim=dim,
    dim_head=attention_head_dim,
    heads=num_attention_heads,
    out_dim=dim,
    bias=True,
    eps=1e-06,
    pre_only=True,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)
norm instance-attribute ¶
norm = AdaLayerNormZeroSingle(
    dim,
    quant_config=_safe_quant_config(quant_config),
    prefix=f"{prefix}.norm",
)
proj_mlp instance-attribute ¶
proj_mlp = ReplicatedLinear(
    dim,
    mlp_hidden_dim,
    bias=True,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.proj_mlp",
)
proj_out instance-attribute ¶
proj_out = ReplicatedLinear(
    dim + mlp_hidden_dim,
    dim,
    bias=True,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.proj_out",
)
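The `dim + mlp_hidden_dim` input width of `proj_out` suggests that the attention output (width `dim`) and the MLP hidden activations (width `mlp_hidden_dim`) are concatenated before the final projection. The arithmetic can be sketched as follows; note that `mlp_ratio` and its value of `4.0` are assumptions for illustration and do not appear on this page, and `single_block_widths` is a hypothetical helper.

```python
def single_block_widths(dim: int, mlp_ratio: float = 4.0) -> dict:
    """Feature widths inside a single-stream block (illustrative only).

    mlp_ratio is an assumed hyperparameter; the page only shows that
    proj_mlp maps dim -> mlp_hidden_dim and proj_out maps
    dim + mlp_hidden_dim -> dim.
    """
    mlp_hidden_dim = int(dim * mlp_ratio)
    return {
        "proj_mlp_in": dim,
        "proj_mlp_out": mlp_hidden_dim,
        "proj_out_in": dim + mlp_hidden_dim,  # attention output ++ MLP hidden
        "proj_out_out": dim,
    }
```

Under these assumptions, `dim=3072` yields an MLP hidden width of 12288 and a `proj_out` input width of 15360.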
FluxTransformer2DModel ¶
Bases: Module
The Transformer model introduced in Flux.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| od_config | `OmniDiffusionConfig` | The configuration for the model. | None |
| patch_size | `int`, defaults to `1` | Patch size used to split the input data into small patches. | 1 |
| in_channels | `int`, defaults to `64` | The number of channels in the input. | 64 |
| out_channels | `int`, *optional*, defaults to `None` | The number of channels in the output. If not specified, it defaults to `in_channels`. | None |
| num_layers | `int`, defaults to `19` | The number of dual-stream DiT blocks to use. | 19 |
| num_single_layers | `int`, defaults to `38` | The number of single-stream DiT blocks to use. | 38 |
| attention_head_dim | `int`, defaults to `128` | The number of dimensions to use for each attention head. | 128 |
| num_attention_heads | `int`, defaults to `24` | The number of attention heads to use. | 24 |
| joint_attention_dim | `int`, defaults to `4096` | The number of dimensions to use for the joint attention (embedding/channel dimension of `encoder_hidden_states`). | 4096 |
| pooled_projection_dim | `int`, defaults to `768` | The number of dimensions to use for the pooled projection. | 768 |
| guidance_embeds | `bool`, defaults to `True` | Whether to use guidance embeddings for the guidance-distilled variant of the model. | True |
| axes_dims_rope | `Tuple[int]`, defaults to `(16, 56, 56)` | The dimensions to use for the rotary positional embeddings. | (16, 56, 56) |
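The defaults in the table above are related to one another, which a small sketch can make explicit. The `FluxDims` dataclass below is hypothetical and exists only to show the arithmetic; the relation `inner_dim = num_attention_heads * attention_head_dim` is an assumption based on how `inner_dim` is used in the attributes on this page, not something this page states.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FluxDims:
    """Hypothetical container for the documented default dimensions."""

    attention_head_dim: int = 128
    num_attention_heads: int = 24
    axes_dims_rope: tuple = (16, 56, 56)

    @property
    def inner_dim(self) -> int:
        # Assumed relation: model width = heads * per-head width.
        return self.num_attention_heads * self.attention_head_dim

    @property
    def rope_dim(self) -> int:
        # Total rotary width across all positional axes.
        return sum(self.axes_dims_rope)


dims = FluxDims()
print(dims.inner_dim)  # 3072
print(dims.rope_dim)   # 128
```

With the defaults, the rotary axes sum to 128, which matches `attention_head_dim`, so the rotary embedding covers the full per-head width.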
norm_out instance-attribute ¶
norm_out = AdaLayerNormContinuous(
    inner_dim,
    inner_dim,
    elementwise_affine=False,
    eps=1e-06,
    quant_config=_safe_quant_config(quant_config),
    prefix="norm_out",
)
packed_modules_mapping class-attribute instance-attribute ¶
packed_modules_mapping = {
    "to_qkv": ["to_q", "to_k", "to_v"],
    "add_kv_proj": [
        "add_q_proj",
        "add_k_proj",
        "add_v_proj",
    ],
}
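A mapping of this shape is typically used to route separate per-projection checkpoint weights into a single fused parameter. The sketch below shows one way such a lookup could work; `to_packed_name` and the integer shard index are hypothetical and are not taken from the `vllm_omni` weight loader.

```python
PACKED_MODULES_MAPPING = {
    "to_qkv": ["to_q", "to_k", "to_v"],
    "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"],
}


def to_packed_name(weight_name: str):
    """Map an unfused checkpoint name to (packed_name, shard_index).

    Returns (weight_name, None) when no component of the dotted name
    matches an unfused projection. Components are compared whole, so
    "to_out" is never mistaken for a Q/K/V projection.
    """
    parts = weight_name.split(".")
    for packed, shards in PACKED_MODULES_MAPPING.items():
        for index, shard in enumerate(shards):
            if shard in parts:
                parts[parts.index(shard)] = packed
                return ".".join(parts), index
    return weight_name, None
```

For example, `transformer_blocks.0.attn.to_k.weight` maps to `transformer_blocks.0.attn.to_qkv.weight` with shard index 1, while names such as `norm_out.linear.weight` pass through unchanged.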
proj_out instance-attribute ¶
single_transformer_blocks instance-attribute ¶
single_transformer_blocks = ModuleList(
    [
        FluxSingleTransformerBlock(
            dim=inner_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            quant_config=quant_config,
            prefix=f"single_transformer_blocks.{i}",
        )
        for i in range(num_single_layers)
    ]
)
time_text_embed instance-attribute ¶
time_text_embed = text_time_guidance_cls(
    embedding_dim=inner_dim,
    pooled_projection_dim=pooled_projection_dim,
)
transformer_blocks instance-attribute ¶
transformer_blocks = ModuleList(
    [
        FluxTransformerBlock(
            dim=inner_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            quant_config=_safe_quant_config(quant_config),
            prefix=f"transformer_blocks.{i}",
        )
        for i in range(num_layers)
    ]
)
forward ¶
forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor = None,
    pooled_projections: Tensor = None,
    timestep: LongTensor = None,
    img_ids: Tensor = None,
    txt_ids: Tensor = None,
    guidance: Tensor | None = None,
    joint_attention_kwargs: dict[str, Any] | None = None,
    return_dict: bool = True,
) -> Tensor | Transformer2DModelOutput
The `FluxTransformer2DModel` forward method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_states | `torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)` | Input `hidden_states`. | required |
| encoder_hidden_states | `torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)` | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. | None |
| pooled_projections | `torch.Tensor` of shape `(batch_size, projection_dim)` | Embeddings projected from the embeddings of input conditions. | None |
| timestep | `torch.LongTensor` | Indicates the denoising step. | None |
| img_ids | `torch.Tensor` | The position ids for image tokens. | None |
| txt_ids | `torch.Tensor` | The position ids for text tokens. | None |
| guidance | `torch.Tensor` | Guidance embeddings for the guidance-distilled variant of the model. | None |
| joint_attention_kwargs | `dict`, *optional* | A kwargs dictionary that, if specified, is passed along to the attention processor. | None |
| return_dict | `bool`, *optional*, defaults to `True` | Whether to return a `Transformer2DModelOutput` rather than a plain `Tensor`. | True |
Returns:
| Type | Description |
|---|---|
| `Tensor \| Transformer2DModelOutput` | If `return_dict` is `True`, a `Transformer2DModelOutput` is returned; otherwise a plain `Tensor`. |
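The shapes listed in the parameter table can be collected into a small checker, which is handy when preparing inputs by hand. This is a sketch under stated assumptions: `expected_forward_shapes` is a hypothetical helper, the default widths are taken from the constructor defaults on this page, and `projection_dim` is assumed to equal `pooled_projection_dim`.

```python
def expected_forward_shapes(
    batch_size: int,
    image_sequence_length: int,
    text_sequence_length: int,
    in_channels: int = 64,
    joint_attention_dim: int = 4096,
    pooled_projection_dim: int = 768,
) -> dict:
    """Return the documented input shapes for forward() as plain tuples."""
    return {
        "hidden_states": (batch_size, image_sequence_length, in_channels),
        "encoder_hidden_states": (
            batch_size,
            text_sequence_length,
            joint_attention_dim,
        ),
        # Assumption: projection_dim in the table is pooled_projection_dim.
        "pooled_projections": (batch_size, pooled_projection_dim),
    }


shapes = expected_forward_shapes(
    batch_size=1, image_sequence_length=4096, text_sequence_length=512
)
print(shapes["hidden_states"])  # (1, 4096, 64)
```

Comparing these tuples against `tuple(tensor.shape)` for each input before calling `forward` can catch mismatched sequence lengths or channel counts early.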
FluxTransformerBlock ¶
Bases: Module
attn instance-attribute ¶
attn = FluxAttention(
    query_dim=dim,
    added_kv_proj_dim=dim,
    dim_head=attention_head_dim,
    heads=num_attention_heads,
    out_dim=dim,
    context_pre_only=False,
    bias=True,
    eps=eps,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)
ff instance-attribute ¶
ff = FeedForward(
    dim=dim,
    dim_out=dim,
    quant_config=quant_config,
    prefix=f"{prefix}.ff",
)
ff_context instance-attribute ¶
ff_context = FeedForward(
    dim=dim,
    dim_out=dim,
    quant_config=quant_config,
    prefix=f"{prefix}.ff_context",
)
norm1 instance-attribute ¶
norm1 = AdaLayerNormZero(
    dim,
    quant_config=quant_config,
    prefix=f"{prefix}.norm1",
)
norm1_context instance-attribute ¶
norm1_context = AdaLayerNormZero(
    dim,
    quant_config=quant_config,
    prefix=f"{prefix}.norm1_context",
)