vllm_omni.diffusion.models.dreamid_omni.wan2_2 ¶
WAN_CROSSATTENTION_CLASSES module-attribute ¶
WAN_CROSSATTENTION_CLASSES = {
"t2v_cross_attn": WanT2VCrossAttention,
"i2v_cross_attn": WanI2VCrossAttention,
}
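A mapping like this is typically used to pick a cross-attention implementation from a string key. The following is a minimal sketch of that dispatch pattern, not the module's actual construction code: the stub classes, the `build_cross_attn` helper, and the reduced constructor signature are all hypothetical stand-ins.

```python
class T2VCrossAttentionStub:
    """Hypothetical stand-in for WanT2VCrossAttention."""

    def __init__(self, dim, num_heads):
        self.dim = dim
        self.num_heads = num_heads


class I2VCrossAttentionStub:
    """Hypothetical stand-in for WanI2VCrossAttention."""

    def __init__(self, dim, num_heads):
        self.dim = dim
        self.num_heads = num_heads


# Same keys as WAN_CROSSATTENTION_CLASSES, but pointing at the stubs above.
CROSSATTENTION_CLASSES = {
    "t2v_cross_attn": T2VCrossAttentionStub,
    "i2v_cross_attn": I2VCrossAttentionStub,
}


def build_cross_attn(cross_attn_type, dim, num_heads):
    """Look up the class by key and instantiate it (illustrative helper)."""
    try:
        cls = CROSSATTENTION_CLASSES[cross_attn_type]
    except KeyError:
        known = ", ".join(sorted(CROSSATTENTION_CLASSES))
        raise ValueError(
            f"unknown cross_attn_type {cross_attn_type!r}; expected one of: {known}"
        ) from None
    return cls(dim, num_heads)


layer = build_cross_attn("t2v_cross_attn", dim=64, num_heads=4)
```

Failing with an explicit `ValueError` that lists the known keys makes a mistyped `cross_attn_type` easier to diagnose than a bare `KeyError`.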
DistributedRMSNorm ¶
Bases: Module
RMSNorm that computes global RMS across tensor parallel ranks.
Mirrors vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py::DistributedRMSNorm
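To illustrate what "global RMS across tensor parallel ranks" means, here is a minimal pure-Python sketch. It assumes each rank holds a contiguous shard of the hidden vector and that the ranks combine their local sums of squares; the plain `sum` stands in for an all-reduce, and the learned weight is omitted. The real class may differ in detail.

```python
import math


def rms_norm(x, eps=1e-6):
    """Reference RMSNorm over the full hidden vector (no learned weight)."""
    mean_sq = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(mean_sq + eps) for v in x]


def sharded_rms_norm(shards, eps=1e-6):
    """Normalize each shard using the RMS of the *whole* vector.

    `shards` is a list with one slice of the hidden vector per simulated rank.
    """
    local_sq_sums = [sum(v * v for v in shard) for shard in shards]
    global_sq_sum = sum(local_sq_sums)  # stand-in for all-reduce(SUM)
    global_dim = sum(len(shard) for shard in shards)
    inv_rms = 1.0 / math.sqrt(global_sq_sum / global_dim + eps)
    return [[v * inv_rms for v in shard] for shard in shards]


full = [1.0, -2.0, 3.0, -4.0]
shards = [full[:2], full[2:]]  # two simulated tensor-parallel ranks

reference = rms_norm(full)
sharded = [v for shard in sharded_rms_norm(shards) for v in shard]
```

Normalizing each shard with only its local RMS would give a different result from the unsharded layer, which is why the statistic has to be shared across ranks.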
WanAttentionBlock ¶
Bases: Module
cross_attn instance-attribute ¶
cross_attn = WanI2VCrossAttention(
dim,
num_heads,
(-1, -1),
qk_norm,
eps,
additional_emb_length,
)
ffn instance-attribute ¶
ffn = Sequential(
ColumnParallelLinear(
dim,
ffn_dim,
bias=True,
gather_output=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.ffn.0" if prefix else "ffn.0",
),
GELU(approximate="tanh"),
RowParallelLinear(
ffn_dim,
dim,
bias=True,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.ffn.2" if prefix else "ffn.2",
),
)
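The column-parallel then row-parallel pairing keeps the hidden activation sharded between the two linear layers (`gather_output=False`, `input_is_parallel=True`), so only one reduction is needed at the end. The sketch below simulates this with two "ranks" in plain Python and checks it against the unsharded computation. The tiny weights and the `matvec` helper are made up for illustration; this is not the vLLM implementation.

```python
import math


def gelu_tanh(v):
    """Tanh approximation of GELU, as selected by GELU(approximate="tanh")."""
    return 0.5 * v * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * v**3)))


def matvec(x, w, b=None):
    """Compute x @ w (+ b) for a vector x and a matrix w given as a list of rows."""
    out = [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]
    return out if b is None else [o + bj for o, bj in zip(out, b)]


x = [0.5, -1.0]                                           # dim = 2
w1 = [[0.1, 0.2, -0.3, 0.4], [0.5, -0.6, 0.7, 0.8]]       # dim x ffn_dim
b1 = [0.01, 0.02, 0.03, 0.04]
w2 = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6], [0.7, 0.8]]   # ffn_dim x dim
b2 = [0.05, -0.05]

# Unsharded reference: Linear -> GELU -> Linear.
dense = matvec([gelu_tanh(h) for h in matvec(x, w1, b1)], w2, b2)

# Two simulated ranks, each owning two columns of w1 and the matching rows of w2.
partials = []
for cols in ([0, 1], [2, 3]):
    w1_shard = [[row[j] for j in cols] for row in w1]
    b1_shard = [b1[j] for j in cols]
    hidden = [gelu_tanh(h) for h in matvec(x, w1_shard, b1_shard)]  # stays sharded
    w2_shard = [w2[j] for j in cols]
    partials.append(matvec(hidden, w2_shard))  # partial output, bias not yet added

# Summing the partial outputs stands in for the all-reduce; the bias is added once.
sharded = [sum(p) + b for p, b in zip(zip(*partials), b2)]
```

Because GELU is applied elementwise, it can run on each rank's slice of the hidden activation without any communication.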
norm3 instance-attribute ¶
self_attn instance-attribute ¶
self_attn = WanSelfAttention(
dim,
num_heads,
window_size,
qk_norm,
eps,
quant_config=quant_config,
prefix=f"{prefix}.self_attn" if prefix else "self_attn",
)
forward ¶
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Shape [B, L, C] | required |
| e | Tensor | Shape [B, L1, 6, C] | required |
| seq_lens | Tensor | Shape [B], length of each sequence in batch | required |
| grid_sizes | Tensor | Shape [B, 3], the second dimension contains (F, H, W) | required |
| freqs | Tensor | Rope freqs, shape [1024, C / num_heads / 2] | required |
WanModel ¶
Bases: ModelMixin, ConfigMixin
Wan diffusion backbone supporting text-to-video, image-to-video, and text-to-audio.
blocks instance-attribute ¶
blocks = ModuleList(
[
(
WanAttentionBlock(
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size,
qk_norm,
cross_attn_norm,
eps,
additional_emb_length,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}"
if prefix
else f"blocks.{layer_idx}",
)
)
for layer_idx in (range(num_layers))
]
)
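The `prefix=... if prefix else ...` expressions build dotted module names without a leading dot when the parent prefix is empty. A small sketch of the same naming rule, using a hypothetical `child_prefix` helper that is not part of the module:

```python
def child_prefix(prefix: str, name: str) -> str:
    """Join a parent prefix and a child name, skipping the dot for an empty prefix."""
    return f"{prefix}.{name}" if prefix else name


# Top-level model with an empty prefix.
names = [child_prefix("", f"blocks.{layer_idx}") for layer_idx in range(2)]

# Nested case: a block under a parent called "model", then its first FFN layer.
nested = child_prefix(child_prefix("model", "blocks.0"), "ffn.0")
```

Here `names` is `["blocks.0", "blocks.1"]` and `nested` is `"model.blocks.0.ffn.0"`.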
patch_embedding instance-attribute ¶
patch_embedding = Sequential(
ChannelLastConv1d(
in_dim, dim, kernel_size=7, padding=3
),
SiLU(),
ConvMLP(dim, dim * 4, kernel_size=7, padding=3),
)
temporal_rope_scaling_factor instance-attribute ¶
text_embedding instance-attribute ¶
time_embedding instance-attribute ¶
prepare_transformer_block_kwargs ¶
unpatchify ¶
unpatchify(x, grid_sizes) -> list[Tensor]
Reconstruct video tensors from patch embeddings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | List[Tensor] | List of patchified features, each with shape [L, C_out * prod(patch_size)] | required |
| grid_sizes | Tensor | Original spatial-temporal grid dimensions before patching, shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) | required |
Returns:
| Type | Description |
|---|---|
| list[Tensor] | Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] |
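The shape bookkeeping behind `unpatchify` can be sketched without any tensor library. The helper below is hypothetical and only checks shapes: it assumes each grid cell expands by the corresponding `patch_size` factor, and the `(1, 2, 2)` patch size and `out_dim=16` in the example are illustrative values, not taken from the model's configuration.

```python
import math


def unpatchified_shape(feature_shape, grid_size, patch_size, out_dim):
    """Return the (C_out, F, H, W) shape implied by one patchified sample.

    feature_shape: (L, C_out * prod(patch_size)) for a single sample.
    grid_size:     (F_patches, H_patches, W_patches) for that sample.
    """
    seq_len, width = feature_shape
    if width != out_dim * math.prod(patch_size):
        raise ValueError("feature width must equal C_out * prod(patch_size)")
    if seq_len < math.prod(grid_size):
        raise ValueError("sequence is shorter than the patch grid")
    # Each grid cell expands by its patch size along the matching axis.
    return (out_dim, *(g * p for g, p in zip(grid_size, patch_size)))


shape = unpatchified_shape(
    feature_shape=(24, 64),  # L = 2 * 3 * 4, width = 16 * (1 * 2 * 2)
    grid_size=(2, 3, 4),
    patch_size=(1, 2, 2),
    out_dim=16,
)
```

With these inputs `shape` is `(16, 2, 6, 8)`. The sequence-length check uses `<` rather than `!=` on the assumption that sequences may be padded beyond the grid size.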
WanSelfAttention ¶
Bases: Module
Optimized self-attention module using vLLM layers.
attn instance-attribute ¶
attn = Attention(
num_heads=num_heads,
head_size=head_dim,
num_kv_heads=num_kv_heads,
softmax_scale=1.0 / head_dim**0.5,
causal=False,
prefix=prefix,
)
norm_k instance-attribute ¶
norm_k = (
DistributedRMSNorm(tp_kv_dim, eps=eps)
if qk_norm
else Identity()
)
norm_q instance-attribute ¶
norm_q = (
DistributedRMSNorm(tp_inner_dim, eps=eps)
if qk_norm
else Identity()
)
o instance-attribute ¶
o = RowParallelLinear(
dim,
dim,
bias=True,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o" if prefix else "o",
)
to_qkv instance-attribute ¶
to_qkv = QKVParallelLinear(
hidden_size=dim,
head_size=head_dim,
total_num_heads=num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.to_qkv" if prefix else "to_qkv",
)
forward ¶
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Shape [B, L, C] | required |
| seq_lens | Tensor | Shape [B] | required |
| grid_sizes | Tensor | Shape [B, 3], the second dimension contains (F, H, W) | required |
| freqs | Tensor | Rope freqs, shape [1024, C / num_heads / 2] | required |
| ref_lengths | Tensor | Shape [B] | None |
WanT2VCrossAttention ¶
Bases: Module
Text -> latent cross-attention.
attn instance-attribute ¶
attn = Attention(
num_heads=num_heads,
head_size=head_dim,
num_kv_heads=num_kv_heads,
softmax_scale=1.0 / head_dim**0.5,
causal=False,
prefix=prefix,
disable_kv_quant=True,
)
k instance-attribute ¶
k = ColumnParallelLinear(
dim,
dim,
bias=True,
gather_output=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.k" if prefix else "k",
)
norm_k instance-attribute ¶
norm_k = (
DistributedRMSNorm(tp_inner_dim, eps=eps)
if qk_norm
else Identity()
)
norm_q instance-attribute ¶
norm_q = (
DistributedRMSNorm(tp_inner_dim, eps=eps)
if qk_norm
else Identity()
)
o instance-attribute ¶
o = RowParallelLinear(
dim,
dim,
bias=True,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o" if prefix else "o",
)
q instance-attribute ¶
q = ColumnParallelLinear(
dim,
dim,
bias=True,
gather_output=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q" if prefix else "q",
)
v instance-attribute ¶
v = ColumnParallelLinear(
dim,
dim,
bias=True,
gather_output=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.v" if prefix else "v",
)
forward ¶
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Shape [B, L1, C] | required |
| context | Tensor | Shape [B, L2, C] | required |
| context_lens | Tensor | Shape [B] | required |
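As a rough illustration of how these three inputs relate, the sketch below runs single-head, non-causal cross-attention for one batch element in plain Python, with the same `1 / sqrt(head_dim)` scaling as `softmax_scale`. It skips the `q`, `k`, `v`, and `o` projections and the QK normalization. It also assumes that `context_lens` marks how many context tokens are valid, so the trailing padding is ignored; how the real `forward` consumes `context_lens` is not shown in this reference.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attend(queries, keys, values, context_len):
    """Attend from L1 query tokens to the first `context_len` context tokens."""
    head_dim = len(queries[0])
    scale = 1.0 / head_dim**0.5
    keys, values = keys[:context_len], values[:context_len]  # drop padded tokens
    out = []
    for q in queries:
        weights = softmax([scale * sum(a * b for a, b in zip(q, k)) for k in keys])
        out.append(
            [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]
        )
    return out


queries = [[1.0, 0.0], [0.0, 1.0]]            # L1 = 2 latent tokens
keys = [[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]]   # L2 = 3, last row is padding
values = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]

attended = cross_attend(queries, keys, values, context_len=2)
```

The output has one row per query token (L1), and the padded third context token has no influence on it.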