Skip to content

vllm_omni.diffusion.models.omnigen2.omnigen2_transformer

logger module-attribute

logger = getLogger(__name__)

Lumina2CombinedTimestepCaptionEmbedding

Bases: Module

caption_embedder instance-attribute

caption_embedder = Sequential(
    RMSNorm(text_feat_dim, eps=norm_eps),
    Linear(text_feat_dim, hidden_size, bias=True),
)

time_proj instance-attribute

time_proj = Timesteps(
    num_channels=frequency_embedding_size,
    flip_sin_to_cos=True,
    downscale_freq_shift=0.0,
    scale=timestep_scale,
)

timestep_embedder instance-attribute

timestep_embedder = TimestepEmbedding(
    in_channels=frequency_embedding_size,
    time_embed_dim=min(hidden_size, 1024),
)

forward

forward(
    timestep: Tensor,
    text_hidden_states: Tensor,
    dtype: dtype,
) -> tuple[Tensor, Tensor]

LuminaFeedForward

Bases: Module

A feed-forward layer.

Parameters:

Name Type Description Default
dim `int`

The dimensionality of the input and output tensors.

required
inner_dim `int`

The intermediate dimension of the feedforward layer.

required
multiple_of `int`, *optional*

Value that the hidden dimension should be a multiple of.

256
ffn_dim_multiplier `float`, *optional*

Custom multiplier for the hidden dimension. Defaults to None.

None

act_fn instance-attribute

act_fn = get_act_and_mul_fn('silu')

down_proj instance-attribute

down_proj = RowParallelLinear(
    inner_dim,
    dim,
    bias=False,
    input_is_parallel=True,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.down_proj",
)

gate_up_proj instance-attribute

gate_up_proj = MergedColumnParallelLinear(
    dim,
    [inner_dim, inner_dim],
    bias=False,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.gate_up_proj",
)

forward

forward(x)

LuminaLayerNormContinuous

Bases: Module

linear_1 instance-attribute

linear_1 = Linear(
    conditioning_embedding_dim, embedding_dim, bias=bias
)

linear_2 instance-attribute

linear_2 = None

norm instance-attribute

norm = LayerNorm(
    embedding_dim, eps, elementwise_affine, bias
)

silu instance-attribute

silu = SiLU()

forward

forward(
    x: Tensor, conditioning_embedding: Tensor
) -> Tensor

LuminaRMSNormZero

Bases: Module

Norm layer implementing adaptive RMS normalization (RMSNorm-Zero).

Parameters:

Name Type Description Default
embedding_dim `int`

The size of each embedding vector.

required

linear instance-attribute

linear = Linear(
    min(embedding_dim, 1024), 4 * embedding_dim, bias=True
)

norm instance-attribute

norm = RMSNorm(embedding_dim, eps=norm_eps)

silu instance-attribute

silu = SiLU()

forward

forward(
    x: Tensor, emb: Tensor | None = None
) -> tuple[Tensor, Tensor, Tensor, Tensor]

OmniGen2Attention

Bases: Module

attn instance-attribute

attn = Attention(
    num_heads=num_heads,
    head_size=head_dim,
    softmax_scale=1.0 / head_dim**0.5,
    causal=False,
)

dim instance-attribute

dim = dim

eps instance-attribute

eps = eps

head_dim instance-attribute

head_dim = dim // num_heads

norm_k instance-attribute

norm_k = RMSNorm(head_dim, eps=eps)

norm_q instance-attribute

norm_q = RMSNorm(head_dim, eps=eps)

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

to_out instance-attribute

to_out = ModuleList(
    [
        RowParallelLinear(
            dim,
            dim,
            bias=False,
            input_is_parallel=False,
            quant_config=quant_config,
            return_bias=False,
            prefix=f"{prefix}.to_out.0",
        )
    ]
)

to_qkv instance-attribute

to_qkv = QKVParallelLinear(
    hidden_size=dim,
    head_size=head_dim,
    total_num_heads=num_heads,
    total_num_kv_heads=num_kv_heads,
    disable_tp=True,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.to_qkv",
)

forward

forward(
    hidden_states: Tensor,
    attention_mask: Tensor | None = None,
    image_rotary_emb: Tensor | None = None,
) -> Tensor

Compute attention using flash attention.

Parameters:

Name Type Description Default
hidden_states Tensor

Hidden states tensor of shape (batch_size, seq_len, hidden_dim)

required
attention_mask Tensor | None

Optional attention mask tensor

None
image_rotary_emb Tensor | None

Optional rotary embeddings for image tokens

None

Returns:

Type Description
Tensor

torch.Tensor: Processed hidden states after attention computation

OmniGen2RotaryPosEmbed

Bases: Module

axes_dim instance-attribute

axes_dim = axes_dim

axes_lens instance-attribute

axes_lens = axes_lens

patch_size instance-attribute

patch_size = patch_size

theta instance-attribute

theta = theta

forward

forward(
    freqs_cis,
    attention_mask,
    l_effective_ref_img_len,
    l_effective_img_len,
    ref_img_sizes,
    img_sizes,
    device,
)

get_freqs_cis staticmethod

get_freqs_cis(
    axes_dim: tuple[int, int, int],
    axes_lens: tuple[int, int, int],
    theta: int,
) -> list[Tensor]

OmniGen2Transformer2DModel

Bases: Module

OmniGen2 Transformer 2D Model.

A transformer-based diffusion model for image generation with:

- Patch-based image processing
- Rotary position embeddings
- Multi-head attention
- Conditional generation support

Parameters:

Name Type Description Default
patch_size int

Size of image patches

2
in_channels int

Number of input channels

16
out_channels int | None

Number of output channels (defaults to in_channels)

None
hidden_size int

Size of hidden layers

2520
num_layers int

Number of transformer layers

32
num_refiner_layers int

Number of refiner layers

2
num_attention_heads int

Number of attention heads

21
num_kv_heads int

Number of key-value heads

7
multiple_of int

Value that the hidden dimension should be a multiple of

256
ffn_dim_multiplier float | None

Multiplier for feed-forward network dimension

None
norm_eps float

Epsilon value for normalization layers

1e-05
axes_dim_rope tuple[int, int, int]

Dimensions for rotary position embeddings

(40, 40, 40)
axes_lens tuple[int, int, int]

Lengths for rotary position embeddings

(1024, 1664, 1664)
text_feat_dim int

Dimension of text features

2048
timestep_scale float

Scale factor for timestep embeddings

1000.0

config instance-attribute

config = SimpleNamespace(
    patch_size=patch_size,
    in_channels=in_channels,
    out_channels=out_channels or in_channels,
    hidden_size=hidden_size,
    axes_dim_rope=axes_dim_rope,
    axes_lens=axes_lens,
)

context_refiner instance-attribute

context_refiner = ModuleList(
    [
        (
            OmniGen2TransformerBlock(
                hidden_size,
                num_attention_heads,
                num_kv_heads,
                multiple_of,
                ffn_dim_multiplier,
                norm_eps,
                modulation=False,
                quant_config=quant_config,
                prefix=f"context_refiner.{i}",
            )
        )
        for i in (range(num_refiner_layers))
    ]
)

image_index_embedding instance-attribute

image_index_embedding = Parameter(randn(5, hidden_size))

layers instance-attribute

layers = ModuleList(
    [
        (
            OmniGen2TransformerBlock(
                hidden_size,
                num_attention_heads,
                num_kv_heads,
                multiple_of,
                ffn_dim_multiplier,
                norm_eps,
                modulation=True,
                quant_config=quant_config,
                prefix=f"layers.{i}",
            )
        )
        for i in (range(num_layers))
    ]
)

noise_refiner instance-attribute

noise_refiner = ModuleList(
    [
        (
            OmniGen2TransformerBlock(
                hidden_size,
                num_attention_heads,
                num_kv_heads,
                multiple_of,
                ffn_dim_multiplier,
                norm_eps,
                modulation=True,
                quant_config=quant_config,
                prefix=f"noise_refiner.{i}",
            )
        )
        for i in (range(num_refiner_layers))
    ]
)

norm_out instance-attribute

norm_out = LuminaLayerNormContinuous(
    embedding_dim=hidden_size,
    conditioning_embedding_dim=min(hidden_size, 1024),
    elementwise_affine=False,
    eps=1e-06,
    bias=True,
    out_dim=patch_size * patch_size * out_channels,
)

out_channels instance-attribute

out_channels = out_channels

ref_image_patch_embedder instance-attribute

ref_image_patch_embedder = Linear(
    in_features=patch_size * patch_size * in_channels,
    out_features=hidden_size,
)

ref_image_refiner instance-attribute

ref_image_refiner = ModuleList(
    [
        (
            OmniGen2TransformerBlock(
                hidden_size,
                num_attention_heads,
                num_kv_heads,
                multiple_of,
                ffn_dim_multiplier,
                norm_eps,
                modulation=True,
                quant_config=quant_config,
                prefix=f"ref_image_refiner.{i}",
            )
        )
        for i in (range(num_refiner_layers))
    ]
)

rope_embedder instance-attribute

rope_embedder = OmniGen2RotaryPosEmbed(
    theta=10000,
    axes_dim=axes_dim_rope,
    axes_lens=axes_lens,
    patch_size=patch_size,
)

time_caption_embed instance-attribute

time_caption_embed = (
    Lumina2CombinedTimestepCaptionEmbedding(
        hidden_size=hidden_size,
        text_feat_dim=text_feat_dim,
        norm_eps=norm_eps,
        timestep_scale=timestep_scale,
    )
)

x_embedder instance-attribute

x_embedder = Linear(
    in_features=patch_size * patch_size * in_channels,
    out_features=hidden_size,
)

flat_and_pad_to_seq

flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)

forward

forward(
    hidden_states: Tensor | list[Tensor],
    timestep: Tensor,
    text_hidden_states: Tensor,
    freqs_cis: Tensor,
    text_attention_mask: Tensor,
    ref_image_hidden_states: list[list[Tensor]]
    | None = None,
    return_dict: bool = False,
) -> Tensor | Transformer2DModelOutput

img_patch_embed_and_refine

img_patch_embed_and_refine(
    hidden_states,
    ref_image_hidden_states,
    padded_img_mask,
    padded_ref_img_mask,
    noise_rotary_emb,
    ref_img_rotary_emb,
    l_effective_ref_img_len,
    l_effective_img_len,
    temb,
)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

OmniGen2TransformerBlock

Bases: Module

Transformer block for OmniGen2 model.

This block implements a transformer layer with:

- Multi-head attention with flash attention
- Feed-forward network with SwiGLU activation
- RMS normalization
- Optional modulation for conditional generation

Parameters:

Name Type Description Default
dim int

Dimension of the input and output tensors

required
num_attention_heads int

Number of attention heads

required
num_kv_heads int

Number of key-value heads

required
multiple_of int

Value that the hidden dimension should be a multiple of

required
ffn_dim_multiplier float

Multiplier for the feed-forward network dimension

required
norm_eps float

Epsilon value for normalization layers

required
modulation bool

Whether to use modulation for conditional generation

True

attn instance-attribute

attn = OmniGen2Attention(
    dim=dim,
    num_heads=num_attention_heads,
    num_kv_heads=num_kv_heads,
    eps=1e-05,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)

feed_forward instance-attribute

feed_forward = LuminaFeedForward(
    dim=dim,
    inner_dim=4 * dim,
    multiple_of=multiple_of,
    ffn_dim_multiplier=ffn_dim_multiplier,
    quant_config=quant_config,
    prefix=f"{prefix}.feed_forward",
)

ffn_norm1 instance-attribute

ffn_norm1 = RMSNorm(dim, eps=norm_eps)

ffn_norm2 instance-attribute

ffn_norm2 = RMSNorm(dim, eps=norm_eps)

head_dim instance-attribute

head_dim = dim // num_attention_heads

modulation instance-attribute

modulation = modulation

norm1 instance-attribute

norm1 = LuminaRMSNormZero(
    embedding_dim=dim,
    norm_eps=norm_eps,
    norm_elementwise_affine=True,
    quant_config=quant_config,
    prefix=f"{prefix}.norm1",
)

norm2 instance-attribute

norm2 = RMSNorm(dim, eps=norm_eps)

forward

forward(
    hidden_states: Tensor,
    attention_mask: Tensor,
    image_rotary_emb: Tensor,
    temb: Tensor | None = None,
) -> Tensor

Forward pass of the transformer block.

Parameters:

Name Type Description Default
hidden_states Tensor

Input hidden states tensor

required
attention_mask Tensor

Attention mask tensor

required
image_rotary_emb Tensor

Rotary embeddings for image tokens

required
temb Tensor | None

Optional timestep embedding tensor

None

Returns:

Type Description
Tensor

torch.Tensor: Output hidden states after transformer block processing

TimestepEmbedding

Bases: Module

act instance-attribute

act = get_activation(act_fn)

cond_proj instance-attribute

cond_proj = Linear(cond_proj_dim, in_channels, bias=False)

linear_1 instance-attribute

linear_1 = Linear(
    in_channels, time_embed_dim, sample_proj_bias
)

linear_2 instance-attribute

linear_2 = Linear(
    time_embed_dim, time_embed_dim_out, sample_proj_bias
)

post_act instance-attribute

post_act = None

forward

forward(sample, condition=None)

apply_rotary_emb

apply_rotary_emb(
    x: Tensor,
    freqs_cis: Tensor | tuple[Tensor],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> tuple[Tensor, Tensor]

Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are returned as real tensors.

Parameters:

Name Type Description Default
x `torch.Tensor`

Query or key tensor to which rotary embeddings are applied, of shape [B, H, S, D].

required
freqs_cis `torch.Tensor | Tuple[torch.Tensor]`

Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

required

Returns:

Type Description
tuple[Tensor, Tensor]

Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.