Skip to content

vllm_omni.diffusion.models.sd3.sd3_transformer

logger module-attribute

logger = init_logger(__name__)

FeedForward

Bases: Module

net instance-attribute

net = ModuleList([])

forward

forward(hidden_states: Tensor, *args, **kwargs) -> Tensor

GELU

Bases: Module

approximate instance-attribute

approximate = approximate

proj instance-attribute

proj = ColumnParallelLinear(dim_in, dim_out, bias=bias)

forward

forward(hidden_states)

SD3CrossAttention

Bases: Module

add_kv_proj instance-attribute

add_kv_proj = QKVParallelLinear(
    added_kv_proj_dim,
    head_size=inner_kv_dim // num_heads,
    total_num_heads=num_heads,
)

attn instance-attribute

attn = Attention(
    num_heads=num_heads,
    head_size=head_dim,
    softmax_scale=1.0 / head_dim**0.5,
    causal=False,
)

dim instance-attribute

dim = dim

eps instance-attribute

eps = eps

head_dim instance-attribute

head_dim = dim // num_heads

inner_dim instance-attribute

inner_dim = (
    out_dim if out_dim is not None else head_dim * num_heads
)

inner_kv_dim instance-attribute

inner_kv_dim = inner_dim

norm_added_k instance-attribute

norm_added_k = RMSNorm(head_dim, eps=eps)

norm_added_q instance-attribute

norm_added_q = RMSNorm(head_dim, eps=eps)

norm_k instance-attribute

norm_k = (
    RMSNorm(head_dim, eps=eps) if qk_norm else Identity()
)

norm_q instance-attribute

norm_q = (
    RMSNorm(head_dim, eps=eps) if qk_norm else Identity()
)

num_heads instance-attribute

num_heads = num_heads

parallel_attention instance-attribute

parallel_attention = parallel_attention

qk_norm instance-attribute

qk_norm = qk_norm

to_add_out instance-attribute

to_add_out = RowParallelLinear(
    inner_dim, dim, bias=out_bias
)

to_out instance-attribute

to_out = ModuleList([])

to_qkv instance-attribute

to_qkv = QKVParallelLinear(
    hidden_size=dim,
    head_size=head_dim,
    total_num_heads=num_heads,
)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor | None = None,
)

SD3PatchEmbed

Bases: Module

2D Image to Patch Embedding with support for SD3.

Parameters:

Name Type Description Default
patch_size `int`, defaults to `16`

The size of the patches.

16
in_channels `int`, defaults to `3`

The number of input channels.

3
embed_dim `int`, defaults to `768`

The output dimension of the embedding.

768

embed_dim instance-attribute

embed_dim = embed_dim

patch_size instance-attribute

patch_size = patch_size

proj instance-attribute

proj = Conv2d(
    in_channels,
    embed_dim,
    kernel_size=(patch_size, patch_size),
    stride=patch_size,
    bias=True,
)

forward

forward(latent)

SD3Transformer2DModel

Bases: Module

The Transformer model introduced in Stable Diffusion 3.

attention_head_dim instance-attribute

attention_head_dim = attention_head_dim

caption_projection_dim instance-attribute

caption_projection_dim = caption_projection_dim

context_embedder instance-attribute

context_embedder = ReplicatedLinear(
    joint_attention_dim, caption_projection_dim
)

dual_attention_layers instance-attribute

dual_attention_layers = (
    dual_attention_layers
    if hasattr(model_config, "dual_attention_layers")
    else ()
)

in_channels instance-attribute

in_channels = in_channels

inner_dim instance-attribute

inner_dim = num_attention_heads * attention_head_dim

joint_attention_dim instance-attribute

joint_attention_dim = joint_attention_dim

norm_out instance-attribute

norm_out = AdaLayerNormContinuous(
    inner_dim,
    inner_dim,
    elementwise_affine=False,
    eps=1e-06,
)

num_attention_heads instance-attribute

num_attention_heads = num_attention_heads

num_layers instance-attribute

num_layers = num_layers

out_channels instance-attribute

out_channels = out_channels

parallel_config instance-attribute

parallel_config = parallel_config

patch_size instance-attribute

patch_size = patch_size

pooled_projection_dim instance-attribute

pooled_projection_dim = pooled_projection_dim

pos_embed instance-attribute

pos_embed = PatchEmbed(
    height=sample_size,
    width=sample_size,
    patch_size=patch_size,
    in_channels=in_channels,
    embed_dim=inner_dim,
    pos_embed_max_size=pos_embed_max_size,
)

pos_embed_max_size instance-attribute

pos_embed_max_size = pos_embed_max_size

proj_out instance-attribute

proj_out = ReplicatedLinear(
    inner_dim,
    patch_size * patch_size * out_channels,
    bias=True,
)

qk_norm instance-attribute

qk_norm = (
    qk_norm if hasattr(model_config, "qk_norm") else ""
)

sample_size instance-attribute

sample_size = sample_size

time_text_embed instance-attribute

time_text_embed = CombinedTimestepTextProjEmbeddings(
    embedding_dim=inner_dim,
    pooled_projection_dim=pooled_projection_dim,
)

transformer_blocks instance-attribute

transformer_blocks = ModuleList(
    [
        (
            SD3TransformerBlock(
                dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                context_pre_only=i == num_layers - 1,
                qk_norm=qk_norm,
                use_dual_attention=True
                if i in dual_attention_layers
                else False,
            )
        )
        for i in (range(num_layers))
    ]
)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor,
    pooled_projections: Tensor,
    timestep: LongTensor,
    return_dict: bool = True,
) -> Tensor | Transformer2DModelOutput

The `SD3Transformer2DModel` forward method.

Parameters:

Name Type Description Default
hidden_states `torch.Tensor` of shape `(batch_size, in_channels, height, width)`

Input hidden_states.

required
encoder_hidden_states `torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`

Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.

required
pooled_projections `torch.Tensor` of shape `(batch_size, projection_dim)`

Embeddings projected from the embeddings of input conditions.

required
timestep `torch.LongTensor`

Used to indicate denoising step.

required
return_dict `bool`, *optional*, defaults to `True`

Whether to return a `Transformer2DModelOutput` (`models.transformer_2d.Transformer2DModelOutput`) instead of a plain tuple.

True

Returns:

Type Description
Tensor | Transformer2DModelOutput

If `return_dict` is True, a `Transformer2DModelOutput` (`models.transformer_2d.Transformer2DModelOutput`) is returned; otherwise, a tuple whose first element is the sample tensor.

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

SD3TransformerBlock

Bases: Module

A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

Reference: https://huggingface.co/papers/2403.03206

Parameters:

Name Type Description Default
dim `int`

The number of channels in the input and output.

required
num_attention_heads `int`

The number of heads to use for multi-head attention.

required
attention_head_dim `int`

The number of channels in each head.

required
context_pre_only `bool`

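Boolean that controls whether this block includes the extra layers associated with processing the context (conditioning) stream. `SD3Transformer2DModel` sets it to `True` only for its final transformer block.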

False

attn instance-attribute

attn = SD3CrossAttention(
    dim=dim,
    num_heads=num_attention_heads,
    head_dim=attention_head_dim,
    added_kv_proj_dim=dim,
    context_pre_only=context_pre_only,
    out_dim=dim,
    qk_norm=True if qk_norm == "rms_norm" else False,
    eps=1e-06,
)

attn2 instance-attribute

attn2 = SD3CrossAttention(
    dim=dim,
    num_heads=num_attention_heads,
    head_dim=attention_head_dim,
    added_kv_proj_dim=None,
    context_pre_only=True,
    out_dim=dim,
    qk_norm=True if qk_norm == "rms_norm" else False,
    eps=1e-06,
)

context_pre_only instance-attribute

context_pre_only = context_pre_only

ff instance-attribute

ff = FeedForward(
    dim=dim, dim_out=dim, activation_fn="gelu-approximate"
)

ff_context instance-attribute

ff_context = FeedForward(
    dim=dim, dim_out=dim, activation_fn="gelu-approximate"
)

norm1 instance-attribute

norm1 = SD35AdaLayerNormZeroX(dim)

norm1_context instance-attribute

norm1_context = AdaLayerNormContinuous(
    dim,
    dim,
    elementwise_affine=False,
    eps=1e-06,
    bias=True,
    norm_type="layer_norm",
)

norm2 instance-attribute

norm2 = LayerNorm(dim, elementwise_affine=False, eps=1e-06)

norm2_context instance-attribute

norm2_context = LayerNorm(
    dim, elementwise_affine=False, eps=1e-06
)

use_dual_attention instance-attribute

use_dual_attention = use_dual_attention

forward

forward(
    hidden_states: FloatTensor,
    encoder_hidden_states: FloatTensor,
    temb: FloatTensor,
) -> tuple[Tensor, Tensor]