Skip to content

Module `vllm_omni.model_executor.models.common.ming.dit`

Class `Attention`

Bases: Module

`attn_mask_enabled` (instance attribute)

attn_mask_enabled = attn_mask_enabled

`dim` (instance attribute)

dim = dim

`dropout` (instance attribute)

dropout = dropout

`heads` (instance attribute)

heads = heads

`inner_dim` (instance attribute)

inner_dim = dim_head * heads

`k_norm` (instance attribute)

k_norm = None

`pe_attn_head` (instance attribute)

pe_attn_head = pe_attn_head

`q_norm` (instance attribute)

q_norm = None

`to_k` (instance attribute)

to_k = nn.Linear(dim, self.inner_dim)

`to_out` (instance attribute)

to_out = nn.ModuleList([])

`to_q` (instance attribute)

to_q = nn.Linear(dim, self.inner_dim)

`to_v` (instance attribute)

to_v = nn.Linear(dim, self.inner_dim)

forward

forward(x, mask=None, rope=None)

Class `CondEmbedder`

Bases: Module

`cond_embedder` (instance attribute)

cond_embedder = nn.Linear(input_feature_size, hidden_size)

forward

forward(llm_cond)

Class `DiTBlock`

Bases: Module

`attn` (instance attribute)

attn = Attention(
    dim=hidden_size,
    heads=num_heads,
    dim_head=hidden_size // num_heads,
    dropout=dropout,
    qk_norm=qk_norm,
    pe_attn_head=pe_attn_head,
    attn_mask_enabled=attn_mask_enabled,
)

`mlp` (instance attribute)

mlp = FeedForward(
    dim=hidden_size,
    mult=mlp_ratio,
    dropout=dropout,
    approximate="tanh",
)

`norm1` (instance attribute)

norm1 = RMSNorm(hidden_size)

`norm2` (instance attribute)

norm2 = RMSNorm(hidden_size)

forward

forward(x, mask, rope)

Class `FeedForward`

Bases: Module

`ff` (instance attribute)

ff = nn.Sequential(
    project_in,
    nn.Dropout(dropout),
    nn.Linear(inner_dim, dim_out),
)

forward

forward(x)

Class `FinalLayer`

Bases: Module

`linear` (instance attribute)

linear = nn.Linear(hidden_size, out_channels, bias=True)

`norm_final` (instance attribute)

norm_final = RMSNorm(hidden_size)

forward

forward(x)

Class `RMSNorm`

Bases: Module

`eps` (instance attribute)

eps = eps

`weight` (instance attribute)

weight = nn.Parameter(torch.ones(dim))

forward

forward(x)

get_epss_timesteps

get_epss_timesteps(n, device, dtype)