API reference: vllm_omni.diffusion.models.sdxl.sdxl_unet

logger module-attribute

logger = init_logger(__name__)

SDXLAddTimestepEmbedding

Bases: Module

act instance-attribute

act = nn.SiLU()

linear_1 instance-attribute

linear_1 = nn.Linear(
    text_embed_dim + addition_time_embed_dim * 6,
    time_embed_dim,
)

linear_2 instance-attribute

linear_2 = nn.Linear(time_embed_dim, time_embed_dim)

forward

forward(text_embeds: Tensor, time_ids: Tensor) -> Tensor

SDXLBasicTransformerBlock

Bases: Module

attention_head_dim instance-attribute

attention_head_dim = attention_head_dim

attn1 instance-attribute

attn1 = SDXLSelfAttention(
    dim=dim,
    num_attention_heads=num_attention_heads,
    attention_head_dim=attention_head_dim,
    prefix=f"{prefix}.attn1",
)

attn2 instance-attribute

attn2 = SDXLCrossAttention(
    dim=dim,
    cross_attention_dim=cross_attention_dim,
    num_attention_heads=num_attention_heads,
    attention_head_dim=attention_head_dim,
    prefix=f"{prefix}.attn2",
)

ff instance-attribute

ff = SDXLFeedForward(dim=dim, inner_dim=dim * 4)

norm1 instance-attribute

norm1 = nn.LayerNorm(dim, elementwise_affine=True)

norm2 instance-attribute

norm2 = nn.LayerNorm(dim, elementwise_affine=True)

norm3 instance-attribute

norm3 = nn.LayerNorm(dim, elementwise_affine=True)

num_attention_heads instance-attribute

num_attention_heads = num_attention_heads

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor,
    temb: None = None,
) -> Tensor

SDXLCrossAttention

Bases: Module

attention instance-attribute

attention = Attention(
    num_heads=num_attention_heads,
    head_size=attention_head_dim,
    causal=False,
    softmax_scale=1.0 / math.sqrt(attention_head_dim),
    role="cross",
    prefix=prefix,
    skip_sequence_parallel=True,
)

to_k instance-attribute

to_k = ColumnParallelLinear(
    cross_attention_dim, inner_dim, bias=False
)

to_out instance-attribute

to_out = nn.ModuleList(
    [RowParallelLinear(inner_dim, dim, bias=True)]
)

to_q instance-attribute

to_q = ColumnParallelLinear(dim, inner_dim, bias=False)

to_v instance-attribute

to_v = ColumnParallelLinear(
    cross_attention_dim, inner_dim, bias=False
)

SDXLCrossAttnDownBlock2D

Bases: Module

attentions instance-attribute

attentions = nn.ModuleList()

downsamplers instance-attribute

downsamplers = None

resnets instance-attribute

resnets = nn.ModuleList()

forward

forward(
    hidden_states: Tensor,
    temb: Tensor,
    encoder_hidden_states: Tensor,
) -> tuple[Tensor, list[Tensor]]

SDXLCrossAttnUpBlock2D

Bases: Module

attentions instance-attribute

attentions = nn.ModuleList()

resnets instance-attribute

resnets = nn.ModuleList()

upsamplers instance-attribute

upsamplers = None

forward

forward(
    hidden_states: Tensor,
    temb: Tensor,
    encoder_hidden_states: Tensor,
    res_hidden_states_tuple: tuple[Tensor, ...],
) -> Tensor

SDXLDownBlock2D

Bases: Module

downsamplers instance-attribute

downsamplers = None

resnets instance-attribute

resnets = nn.ModuleList()

forward

forward(
    hidden_states: Tensor, temb: Tensor
) -> tuple[Tensor, list[Tensor]]

SDXLDownsample2D

Bases: Module

conv instance-attribute

conv = nn.Conv2d(
    channels, channels, kernel_size=3, stride=2, padding=1
)

forward

forward(hidden_states: Tensor) -> Tensor

SDXLFeedForward

Bases: Module

geglu instance-attribute

geglu = SDXLGEGLU(dim, inner_dim)

out_proj instance-attribute

out_proj = RowParallelLinear(inner_dim, dim, bias=True)

forward

forward(hidden_states: Tensor) -> Tensor

SDXLGEGLU

Bases: Module

proj instance-attribute

proj = MergedColumnParallelLinear(
    dim_in, [dim_out, dim_out], bias=True
)

forward

forward(hidden_states: Tensor) -> Tensor

SDXLResnetBlock2D

Bases: Module

conv1 instance-attribute

conv1 = nn.Conv2d(
    in_channels, out_channels, kernel_size=3, padding=1
)

conv2 instance-attribute

conv2 = nn.Conv2d(
    out_channels, out_channels, kernel_size=3, padding=1
)

conv_shortcut instance-attribute

conv_shortcut = None

nonlinearity instance-attribute

nonlinearity = nn.SiLU()

norm1 instance-attribute

norm1 = nn.GroupNorm(groups, in_channels, eps=1e-05)

norm2 instance-attribute

norm2 = nn.GroupNorm(groups, out_channels, eps=1e-05)

time_emb_proj instance-attribute

time_emb_proj = nn.Linear(time_embed_dim, out_channels)

forward

forward(hidden_states: Tensor, temb: Tensor) -> Tensor

SDXLSelfAttention

Bases: Module

attention instance-attribute

attention = Attention(
    num_heads=num_attention_heads,
    head_size=attention_head_dim,
    causal=False,
    softmax_scale=1.0 / math.sqrt(attention_head_dim),
    role="self",
    prefix=prefix,
)

to_out instance-attribute

to_out = nn.ModuleList(
    [RowParallelLinear(inner_dim, dim, bias=True)]
)

to_qkv instance-attribute

to_qkv = QKVParallelLinear(
    hidden_size=dim,
    head_size=attention_head_dim,
    total_num_heads=num_attention_heads,
    bias=False,
)

SDXLTimestepEmbedding

Bases: Module

act instance-attribute

act = nn.SiLU()

linear_1 instance-attribute

linear_1 = nn.Linear(in_channels, time_embed_dim)

linear_2 instance-attribute

linear_2 = nn.Linear(time_embed_dim, time_embed_dim)

forward

forward(sample: Tensor) -> Tensor

SDXLTransformer2DModel

Bases: Module

norm instance-attribute

norm = nn.GroupNorm(32, in_channels, eps=1e-06, affine=True)

proj_in instance-attribute

proj_in = nn.Linear(in_channels, inner_dim)

proj_out instance-attribute

proj_out = nn.Linear(inner_dim, in_channels)

transformer_blocks instance-attribute

transformer_blocks = nn.ModuleList(
    [
        (
            SDXLBasicTransformerBlock(
                dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                cross_attention_dim=cross_attention_dim,
                prefix=f"{prefix}.transformer_blocks.{i}",
            )
        )
        for i in (range(num_layers))
    ]
)

forward

forward(
    hidden_states: Tensor, encoder_hidden_states: Tensor
) -> Tensor

SDXLUNet2DConditionModel

Bases: Module

add_embedding instance-attribute

add_embedding = SDXLAddTimestepEmbedding(
    addition_time_embed_dim=addition_time_embed_dim,
    text_embed_dim=1280,
    time_embed_dim=time_embed_dim,
)

conv_act instance-attribute

conv_act = nn.SiLU()

conv_in instance-attribute

conv_in = nn.Conv2d(
    4, model_channels, kernel_size=3, padding=1
)

conv_norm_out instance-attribute

conv_norm_out = nn.GroupNorm(
    32, block_out_channels[0], eps=1e-05
)

conv_out instance-attribute

conv_out = nn.Conv2d(
    block_out_channels[0], 4, kernel_size=3, padding=1
)

down_blocks instance-attribute

down_blocks = nn.ModuleList()

in_channels instance-attribute

in_channels = 4

mid_block instance-attribute

mid_block = SDXLUNetMidBlock2DCrossAttn(
    in_channels=block_out_channels[2],
    time_embed_dim=time_embed_dim,
    num_attention_heads=block_out_channels[2]
    // num_head_channels,
    cross_attention_dim=cross_attention_dim,
    transformer_layers_per_block=transformer_layers_per_block[
        2
    ],
    prefix="mid_block",
)

od_config instance-attribute

od_config = od_config

out_channels instance-attribute

out_channels = 4

time_embedding instance-attribute

time_embedding = SDXLTimestepEmbedding(
    model_channels, time_embed_dim
)

time_proj_dim instance-attribute

time_proj_dim = model_channels

up_blocks instance-attribute

up_blocks = nn.ModuleList()

forward

forward(
    hidden_states: Tensor,
    timestep: Tensor,
    encoder_hidden_states: Tensor,
    added_cond_kwargs: dict,
    return_dict: bool = False,
) -> tuple[Tensor]

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

SDXLUNetMidBlock2DCrossAttn

Bases: Module

attentions instance-attribute

attentions = nn.ModuleList(
    [
        SDXLTransformer2DModel(
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            in_channels=in_channels,
            num_layers=transformer_layers_per_block,
            cross_attention_dim=cross_attention_dim,
            prefix=f"{prefix}.attentions.0",
        )
    ]
)

resnets instance-attribute

resnets = nn.ModuleList(
    [
        SDXLResnetBlock2D(
            in_channels, in_channels, time_embed_dim
        ),
        SDXLResnetBlock2D(
            in_channels, in_channels, time_embed_dim
        ),
    ]
)

forward

forward(
    hidden_states: Tensor,
    temb: Tensor,
    encoder_hidden_states: Tensor,
) -> Tensor

SDXLUpBlock2D

Bases: Module

resnets instance-attribute

resnets = nn.ModuleList()

upsamplers instance-attribute

upsamplers = None

forward

forward(
    hidden_states: Tensor,
    temb: Tensor,
    res_hidden_states_tuple: tuple[Tensor, ...],
) -> Tensor

SDXLUpsample2D

Bases: Module

conv instance-attribute

conv = nn.Conv2d(
    channels, channels, kernel_size=3, padding=1
)

forward

forward(hidden_states: Tensor) -> Tensor

get_timestep_embedding

get_timestep_embedding(
    timesteps: Tensor,
    embedding_dim: int,
    max_period: int = 10000,
) -> Tensor