Skip to content

vllm_omni.diffusion.models.dreamid_omni.wan2_2

WAN_CROSSATTENTION_CLASSES module-attribute

WAN_CROSSATTENTION_CLASSES = {
    "t2v_cross_attn": WanT2VCrossAttention,
    "i2v_cross_attn": WanI2VCrossAttention,
}

DistributedRMSNorm

Bases: Module

RMSNorm that computes the global RMS across tensor-parallel ranks.

Mirrors vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py::DistributedRMSNorm

eps instance-attribute

eps = eps

weight instance-attribute

weight = Parameter(ones(hidden_size))

forward

forward(x: Tensor) -> Tensor

WanAttentionBlock

Bases: Module

cross_attn instance-attribute

cross_attn = WanI2VCrossAttention(
    dim,
    num_heads,
    (-1, -1),
    qk_norm,
    eps,
    additional_emb_length,
)

cross_attn_norm instance-attribute

cross_attn_norm = cross_attn_norm

dim instance-attribute

dim = dim

eps instance-attribute

eps = eps

ffn instance-attribute

ffn = Sequential(
    ColumnParallelLinear(
        dim,
        ffn_dim,
        bias=True,
        gather_output=False,
        return_bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.ffn.0" if prefix else "ffn.0",
    ),
    GELU(approximate="tanh"),
    RowParallelLinear(
        ffn_dim,
        dim,
        bias=True,
        input_is_parallel=True,
        return_bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.ffn.2" if prefix else "ffn.2",
    ),
)

ffn_dim instance-attribute

ffn_dim = ffn_dim

modulation instance-attribute

modulation = ModulationAdd(dim, 6)

norm1 instance-attribute

norm1 = WanLayerNorm(dim, eps)

norm2 instance-attribute

norm2 = WanLayerNorm(dim, eps)

norm3 instance-attribute

norm3 = (
    WanLayerNorm(dim, eps, elementwise_affine=True)
    if cross_attn_norm
    else Identity()
)

num_heads instance-attribute

num_heads = num_heads

qk_norm instance-attribute

qk_norm = qk_norm

self_attn instance-attribute

self_attn = WanSelfAttention(
    dim,
    num_heads,
    window_size,
    qk_norm,
    eps,
    quant_config=quant_config,
    prefix=f"{prefix}.self_attn" if prefix else "self_attn",
)

window_size instance-attribute

window_size = window_size

forward

forward(
    x, e, seq_lens, grid_sizes, freqs, context, context_lens
)

Parameters:

Name Type Description Default
x Tensor

Shape [B, L, C]

required
e Tensor

Shape [B, L1, 6, C]

required
seq_lens Tensor

Shape [B]; length of each sequence in the batch

required
grid_sizes Tensor

Shape [B, 3], the second dimension contains (F, H, W)

required
freqs Tensor

Rope freqs, shape [1024, C / num_heads / 2]

required

WanI2VCrossAttention

Bases: Module

forward

forward(*args, **kwargs)

WanModel

Bases: ModelMixin, ConfigMixin

Wan diffusion backbone supporting text-to-video, image-to-video, and text-to-audio.

blocks instance-attribute

blocks = ModuleList(
    [
        (
            WanAttentionBlock(
                cross_attn_type,
                dim,
                ffn_dim,
                num_heads,
                window_size,
                qk_norm,
                cross_attn_norm,
                eps,
                additional_emb_length,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}"
                if prefix
                else f"blocks.{layer_idx}",
            )
        )
        for layer_idx in (range(num_layers))
    ]
)

cross_attn_norm instance-attribute

cross_attn_norm = cross_attn_norm

dim instance-attribute

dim = dim

eps instance-attribute

eps = eps

ffn_dim instance-attribute

ffn_dim = ffn_dim

freq_dim instance-attribute

freq_dim = freq_dim

head instance-attribute

head = Head(dim, out_dim, patch_size, eps)

img_emb instance-attribute

img_emb = MLPProj(additional_emb_dim, dim)

in_dim instance-attribute

in_dim = in_dim

is_audio_type instance-attribute

is_audio_type = is_audio_type

is_video_type instance-attribute

is_video_type = is_video_type

model_type instance-attribute

model_type = model_type

num_heads instance-attribute

num_heads = num_heads

num_layers instance-attribute

num_layers = num_layers

out_dim instance-attribute

out_dim = out_dim

patch_embedding instance-attribute

patch_embedding = Sequential(
    ChannelLastConv1d(
        in_dim, dim, kernel_size=7, padding=3
    ),
    SiLU(),
    ConvMLP(dim, dim * 4, kernel_size=7, padding=3),
)

patch_size instance-attribute

patch_size = patch_size

qk_norm instance-attribute

qk_norm = qk_norm

temporal_rope_scaling_factor instance-attribute

temporal_rope_scaling_factor = temporal_rope_scaling_factor

text_dim instance-attribute

text_dim = text_dim

text_embedding instance-attribute

text_embedding = Sequential(
    Linear(text_dim, dim),
    GELU(approximate="tanh"),
    Linear(dim, dim),
)

text_len instance-attribute

text_len = text_len

time_embedding instance-attribute

time_embedding = Sequential(
    Linear(freq_dim, dim), SiLU(), Linear(dim, dim)
)

time_projection instance-attribute

time_projection = Sequential(SiLU(), Linear(dim, dim * 6))

window_size instance-attribute

window_size = window_size

forward

forward(*args, **kwargs)

post_transformer_block_out

post_transformer_block_out(x, grid_sizes, e)

prepare_transformer_block_kwargs

prepare_transformer_block_kwargs(
    x,
    t,
    context,
    seq_len,
    ref_lengths=None,
    freqs_scaling=None,
)

set_rope_params

set_rope_params()

unpatchify

unpatchify(x, grid_sizes) -> list[Tensor]

Reconstruct video tensors from patch embeddings.

Parameters:

Name Type Description Default
x List[Tensor]

List of patchified features, each with shape [L, C_out * prod(patch_size)]

required
grid_sizes Tensor

Original spatial-temporal grid dimensions before patching, shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

required

Returns:

Type Description
list[Tensor]

Reconstructed video tensors, each with shape [C_out, F, H / 8, W / 8]

WanSelfAttention

Bases: Module

Optimized self-attention module using vLLM layers.

attn instance-attribute

attn = Attention(
    num_heads=num_heads,
    head_size=head_dim,
    num_kv_heads=num_kv_heads,
    softmax_scale=1.0 / head_dim**0.5,
    causal=False,
    prefix=prefix,
)

dim instance-attribute

dim = dim

eps instance-attribute

eps = eps

head_dim instance-attribute

head_dim = dim // num_heads

norm_k instance-attribute

norm_k = (
    DistributedRMSNorm(tp_kv_dim, eps=eps)
    if qk_norm
    else Identity()
)

norm_q instance-attribute

norm_q = (
    DistributedRMSNorm(tp_inner_dim, eps=eps)
    if qk_norm
    else Identity()
)

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

o instance-attribute

o = RowParallelLinear(
    dim,
    dim,
    bias=True,
    input_is_parallel=True,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.o" if prefix else "o",
)

qk_norm instance-attribute

qk_norm = qk_norm

to_qkv instance-attribute

to_qkv = QKVParallelLinear(
    hidden_size=dim,
    head_size=head_dim,
    total_num_heads=num_heads,
    bias=True,
    quant_config=quant_config,
    prefix=f"{prefix}.to_qkv" if prefix else "to_qkv",
)

tp_inner_dim instance-attribute

tp_inner_dim = num_heads * head_dim

tp_kv_dim instance-attribute

tp_kv_dim = num_kv_heads * head_dim

window_size instance-attribute

window_size = window_size

forward

forward(
    x,
    seq_lens,
    grid_sizes,
    freqs,
    ref_lengths=None,
    freqs_scaling=None,
)

Parameters:

Name Type Description Default
x Tensor

Shape [B, L, C]

required
seq_lens Tensor

Shape [B]

required
grid_sizes Tensor

Shape [B, 3], the second dimension contains (F, H, W)

required
freqs Tensor

Rope freqs, shape [1024, C / num_heads / 2]

required
ref_lengths Tensor

Shape [B]

None

qkv_fn

qkv_fn(x)

WanT2VCrossAttention

Bases: Module

Text -> latent cross-attention.

attn instance-attribute

attn = Attention(
    num_heads=num_heads,
    head_size=head_dim,
    num_kv_heads=num_kv_heads,
    softmax_scale=1.0 / head_dim**0.5,
    causal=False,
    prefix=prefix,
    disable_kv_quant=True,
)

dim instance-attribute

dim = dim

eps instance-attribute

eps = eps

head_dim instance-attribute

head_dim = dim // num_heads

k instance-attribute

k = ColumnParallelLinear(
    dim,
    dim,
    bias=True,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.k" if prefix else "k",
)

norm_k instance-attribute

norm_k = (
    DistributedRMSNorm(tp_inner_dim, eps=eps)
    if qk_norm
    else Identity()
)

norm_q instance-attribute

norm_q = (
    DistributedRMSNorm(tp_inner_dim, eps=eps)
    if qk_norm
    else Identity()
)

num_heads instance-attribute

num_heads = num_heads // tp_size

num_kv_heads instance-attribute

num_kv_heads = num_heads

o instance-attribute

o = RowParallelLinear(
    dim,
    dim,
    bias=True,
    input_is_parallel=True,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.o" if prefix else "o",
)

q instance-attribute

q = ColumnParallelLinear(
    dim,
    dim,
    bias=True,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.q" if prefix else "q",
)

qk_norm instance-attribute

qk_norm = qk_norm

tp_inner_dim instance-attribute

tp_inner_dim = num_heads * head_dim

v instance-attribute

v = ColumnParallelLinear(
    dim,
    dim,
    bias=True,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.v" if prefix else "v",
)

window_size instance-attribute

window_size = window_size

forward

forward(x, context, context_lens)

Parameters:

Name Type Description Default
x Tensor

Shape [B, L1, C]

required
context Tensor

Shape [B, L2, C]

required
context_lens Tensor

Shape [B]

required

qkv_fn

qkv_fn(x, context)