
vllm_omni.diffusion.models.dreamzero.causal_wan_model

CausalWanModel — 40-layer DiT with causal attention and KV cache.

Key differences from WanTransformer3DModel:

- Causal self-attention (new frames only see history)
- KV cache for streaming inference
- Action/state token support (appended after video tokens)
- Extended RoPE with action/state-specific frequencies
- Inference-only forward with KV cache
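The sketch below shows one plausible reading of "causal" at frame granularity. It assumes, without confirmation from this page, that frames are grouped into blocks of `num_frame_per_block` and that a frame may attend to every frame in its own block and in all earlier blocks, but never to later ones. `block_causal_mask` is a hypothetical helper, not part of the module.

```python
def block_causal_mask(num_frames: int, num_frame_per_block: int) -> list[list[bool]]:
    """Frame-level mask: mask[i][j] is True when frame i may attend to frame j.

    Assumption (illustrative only): a frame sees its own block and all
    earlier blocks, and nothing after its block.
    """
    mask = []
    for i in range(num_frames):
        # One past the last frame of frame i's block, clamped to the sequence length.
        block_end = min((i // num_frame_per_block + 1) * num_frame_per_block, num_frames)
        mask.append([j < block_end for j in range(num_frames)])
    return mask


if __name__ == "__main__":
    for row in block_causal_mask(num_frames=6, num_frame_per_block=2):
        print("".join("x" if allowed else "." for allowed in row))
    # xx....
    # xx....
    # xxxx..
    # xxxx..
    # xxxxxx
    # xxxxxx
```

The result is a staircase: each block sees itself and its predecessors. In a KV-cache path the "history" part of this mask would be implicit, since earlier blocks are only present as cached keys and values.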

WAN_CROSSATTENTION_CLASSES module-attribute

WAN_CROSSATTENTION_CLASSES = {
    "t2v_cross_attn": WanT2VCrossAttention,
    "i2v_cross_attn": WanI2VCrossAttention,
}

CausalHead

Bases: Module

Output norm + linear projection with 2-param modulation. Runs once per step and is not TP-critical, so it uses a plain nn.Linear.

dim instance-attribute

dim = dim

head instance-attribute

head = Linear(dim, out_channels)

modulation instance-attribute

modulation = Parameter(randn(1, 2, dim) / dim ** 0.5)

norm instance-attribute

norm = WanLayerNorm(dim, eps)

out_dim instance-attribute

out_dim = out_dim

patch_size instance-attribute

patch_size = patch_size

forward

forward(x: Tensor, e: Tensor) -> Tensor

Parameters:

Name  Type    Description                                 Default
x     Tensor  [B, L1, C]                                  required
e     Tensor  [B, F, 1, C] (time embedding, unsqueezed)   required
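A pure-Python sketch of what the 2-param modulation might do, assuming the upstream Wan convention: the two learned vectors in `modulation` are added to each frame's time embedding to form a shift and a scale, applied as `norm(x) * (1 + scale) + shift` to every token of that frame. Mapping tokens to frames via `frame_seqlen` is an assumption; `layer_norm` and `head_modulate` are hypothetical helpers, and the final `head` projection is omitted.

```python
import math


def layer_norm(v: list[float], eps: float = 1e-6) -> list[float]:
    # LayerNorm without learnable affine parameters.
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def head_modulate(
    tokens: list[list[float]],      # [L1][C], L1 = num_frames * frame_seqlen
    frame_emb: list[list[float]],   # [F][C], one time embedding per frame
    modulation: list[list[float]],  # [2][C], learned (shift, scale) offsets
    frame_seqlen: int,
) -> list[list[float]]:
    """Hypothetical per-frame shift/scale applied before the output Linear."""
    out = []
    for idx, tok in enumerate(tokens):
        e = frame_emb[idx // frame_seqlen]  # time embedding of this token's frame
        shift = [m + t for m, t in zip(modulation[0], e)]
        scale = [m + t for m, t in zip(modulation[1], e)]
        normed = layer_norm(tok)
        out.append([n * (1 + s) + b for n, s, b in zip(normed, scale, shift)])
    return out  # the real head would now apply Linear(dim, out_dim)
```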

CausalWanAttentionBlock

Bases: Module

Transformer block: self-attn + cross-attn + FFN with 6-param modulation.
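The 6-param modulation can be sketched on a single token vector. This assumes the upstream Wan 2.1 ordering (shift, scale, gate for self-attention, then shift, scale, gate for the FFN) and an unmodulated cross-attention residual. The sub-layers are stand-in callables and `modulated_block` is a hypothetical name. In `CausalWanModel` the six vectors would presumably come from `modulation` plus the per-frame output of `time_projection`, which produces `dim * 6` values.

```python
import math
from typing import Callable

Vec = list[float]


def layer_norm(v: Vec, eps: float = 1e-6) -> Vec:
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def add(a: Vec, b: Vec) -> Vec:
    return [x + y for x, y in zip(a, b)]


def mul(a: Vec, b: Vec) -> Vec:
    return [x * y for x, y in zip(a, b)]


def modulated_block(
    x: Vec,
    e6: list[Vec],  # six [C] vectors: modulation + projected time embedding
    self_attn: Callable[[Vec], Vec],
    cross_attn: Callable[[Vec], Vec],  # norm3 (LayerNorm or Identity) folded in
    ffn: Callable[[Vec], Vec],
) -> Vec:
    shift1, scale1, gate1, shift2, scale2, gate2 = e6
    # Self-attention on the shifted/scaled norm1 output, gated residual.
    h = add(mul(layer_norm(x), [1 + s for s in scale1]), shift1)
    x = add(x, mul(self_attn(h), gate1))
    # Cross-attention residual is not modulated in this sketch.
    x = add(x, cross_attn(x))
    # FFN on the shifted/scaled norm2 output, gated residual.
    h = add(mul(layer_norm(x), [1 + s for s in scale2]), shift2)
    x = add(x, mul(ffn(h), gate2))
    return x
```

Because each modulated branch is multiplied by its gate before the residual add, a zero gate disables that branch for the frame.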

cross_attn instance-attribute

cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](
    dim, num_heads, (-1, -1), qk_norm, eps
)

ffn instance-attribute

ffn = Sequential(
    ColumnParallelLinear(
        dim,
        ffn_dim,
        bias=True,
        gather_output=False,
        return_bias=False,
    ),
    GELU(approximate="tanh"),
    RowParallelLinear(
        ffn_dim,
        dim,
        bias=True,
        input_is_parallel=True,
        return_bias=False,
    ),
)

modulation instance-attribute

modulation = Parameter(randn(1, 6, dim) / dim ** 0.5)

norm1 instance-attribute

norm1 = WanLayerNorm(dim, eps)

norm2 instance-attribute

norm2 = WanLayerNorm(dim, eps)

norm3 instance-attribute

norm3 = (
    WanLayerNorm(dim, eps, elementwise_affine=True)
    if cross_attn_norm
    else Identity()
)

self_attn instance-attribute

self_attn = CausalWanSelfAttention(
    dim=dim,
    num_heads=num_heads,
    frame_seqlen=frame_seqlen,
    local_attn_size=local_attn_size,
    sink_size=sink_size,
    num_frame_per_block=num_frame_per_block,
    qk_norm=qk_norm,
    eps=eps,
    num_action_per_block=num_action_per_block,
    num_state_per_block=num_state_per_block,
)

forward

forward(
    x: Tensor,
    e: Tensor,
    freqs: Tensor,
    freqs_action: Tensor,
    freqs_state: Tensor,
    context: Tensor,
    action_register_length: int | None = None,
    kv_cache: Tensor | None = None,
    crossattn_cache: dict | None = None,
    current_start_frame: int = 0,
    is_tf: bool = True,
) -> tuple[Tensor, Tensor | None]

CausalWanModel

Bases: Module

Causal video diffusion transformer for DreamZero.

Architecture (14B): 40 layers, dim=5120, heads=40, ffn=13824

action_decoder instance-attribute

action_decoder = CategorySpecificMLP(
    num_categories=max_num_embodiments_local,
    input_dim=dim,
    hidden_dim=hidden_size,
    output_dim=action_dim,
)

action_dim instance-attribute

action_dim = action_dim

action_encoder instance-attribute

action_encoder = MultiEmbodimentActionEncoder(
    action_dim=action_dim,
    hidden_size=dim,
    num_embodiments=max_num_embodiments_local,
)

blocks instance-attribute

blocks = ModuleList(
    [
        CausalWanAttentionBlock(
            cross_attn_type,
            dim,
            ffn_dim,
            num_heads,
            frame_seqlen,
            local_attn_size,
            sink_size,
            num_frame_per_block,
            qk_norm,
            cross_attn_norm,
            eps,
            num_action_per_block,
            num_state_per_block,
        )
        for _ in range(num_layers)
    ]
)

dim instance-attribute

dim = dim

frame_seqlen instance-attribute

frame_seqlen = frame_seqlen

freq_dim instance-attribute

freq_dim = freq_dim

freqs instance-attribute

freqs = [
    rope_params(1024, d - 4 * (d // 6)),
    rope_params(1024, 2 * (d // 6)),
    rope_params(1024, 2 * (d // 6)),
]
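The three `freqs` tables split each head's rotary dimension across the temporal, height and width axes. The arithmetic below reproduces those expressions for the 14B configuration, assuming (as in upstream Wan) that `d` is the per-head dimension `dim // num_heads`. `rope_split` is a hypothetical helper.

```python
def rope_split(dim: int, num_heads: int) -> tuple[int, int, int]:
    """Per-axis RoPE widths (temporal, height, width) for one attention head.

    Mirrors the expressions in CausalWanModel.freqs, assuming d = dim // num_heads.
    """
    d = dim // num_heads
    t = d - 4 * (d // 6)
    h = 2 * (d // 6)
    w = 2 * (d // 6)
    return t, h, w


t, h, w = rope_split(dim=5120, num_heads=40)  # 14B config from the docstring
print(t, h, w, t + h + w)  # 44 42 42 128
# rope_params returns dim // 2 complex columns per table: 22 + 21 + 21 = 64
print(t // 2 + h // 2 + w // 2)  # 64
```

Whatever `d % 6` is, the temporal share absorbs the remainder, so the three widths always sum to `d`. Under the same assumption, `freqs_action` and `freqs_state` use the full head dimension `d`.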

freqs_action instance-attribute

freqs_action = rope_params(1024 * 10, d)

freqs_state instance-attribute

freqs_state = rope_params(1024, d)

head instance-attribute

head = CausalHead(dim, out_dim, patch_size, eps)

img_emb instance-attribute

img_emb = MLPProj(1280, dim)

local_attn_size instance-attribute

local_attn_size = (
    max_chunk_size * num_frame_per_block + 1
    if max_chunk_size != -1
    else -1
)

model_type instance-attribute

model_type = model_type

num_action_per_block instance-attribute

num_action_per_block = num_action_per_block

num_frame_per_block instance-attribute

num_frame_per_block = num_frame_per_block

num_heads instance-attribute

num_heads = num_heads

num_layers instance-attribute

num_layers = num_layers

num_state_per_block instance-attribute

num_state_per_block = num_state_per_block

out_dim instance-attribute

out_dim = out_dim

patch_embedding instance-attribute

patch_embedding = Conv3dLayer(
    in_dim, dim, kernel_size=patch_size, stride=patch_size
)

patch_size instance-attribute

patch_size = patch_size

state_encoder instance-attribute

state_encoder = CategorySpecificMLP(
    num_categories=max_num_embodiments_local,
    input_dim=max_state_dim,
    hidden_dim=hidden_size,
    output_dim=dim,
)

text_embedding instance-attribute

text_embedding = Sequential(
    Linear(text_dim, dim),
    GELU(approximate="tanh"),
    Linear(dim, dim),
)

text_len instance-attribute

text_len = text_len

time_embedding instance-attribute

time_embedding = Sequential(
    Linear(freq_dim, dim), SiLU(), Linear(dim, dim)
)

time_projection instance-attribute

time_projection = Sequential(SiLU(), Linear(dim, dim * 6))

forward

forward(*args: Any, **kwargs: Any)

Inference only. Requires kv_cache.

init_weights

init_weights() -> None

Initialize parameters.

unpatchify

unpatchify(x: Tensor, grid_size: Tensor) -> Tensor

Reconstruct video from patch embeddings.
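A pure-Python sketch of the reshape `unpatchify` performs, assuming the upstream Wan 2.1 layout: tokens are ordered by frame, then height, then width, and each token's `out_dim * prod(patch_size)` values are laid out as (p_t, p_h, p_w, channel) with channel fastest (upstream Wan expresses this as `einsum('fhwpqrc->cfphqwr')`). The signature, which takes plain tuples instead of a `grid_size` tensor, is hypothetical.

```python
def unpatchify(
    tokens: list[list[float]],
    grid: tuple[int, int, int],   # (F, H, W) patch grid
    patch: tuple[int, int, int],  # (p_t, p_h, p_w) patch size
    channels: int,
) -> list[list[list[list[float]]]]:
    """Return a [C][F*p_t][H*p_h][W*p_w] nested list."""
    F, H, W = grid
    pt, ph, pw = patch
    out = [[[[0.0] * (W * pw) for _ in range(H * ph)] for _ in range(F * pt)] for _ in range(channels)]
    for f in range(F):
        for h in range(H):
            for w in range(W):
                # Tokens are ordered frame-major, then height, then width.
                tok = tokens[(f * H + h) * W + w]
                for p in range(pt):
                    for q in range(ph):
                        for r in range(pw):
                            for c in range(channels):
                                # Within a token: (p, q, r, c), channel fastest.
                                val = tok[((p * ph + q) * pw + r) * channels + c]
                                out[c][f * pt + p][h * ph + q][w * pw + r] = val
    return out
```

For two tokens on a 1x1x2 grid with 1x2x2 patches, the second token fills the right half of each row.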

CausalWanSelfAttention

Bases: Module

Causal self-attention with KV cache + action/state tokens.

attn instance-attribute

attn = Attention(
    tp_num_heads,
    head_dim,
    causal=False,
    softmax_scale=head_dim**-0.5,
    skip_sequence_parallel=True,
)

dim instance-attribute

dim = dim

frame_seqlen instance-attribute

frame_seqlen = frame_seqlen

head_dim instance-attribute

head_dim = dim // num_heads

k instance-attribute

k = ColumnParallelLinear(
    dim,
    dim,
    bias=True,
    gather_output=False,
    return_bias=False,
)

local_attn_size instance-attribute

local_attn_size = local_attn_size

max_attention_size instance-attribute

max_attention_size = (
    21 * frame_seqlen
    if local_attn_size == -1
    else local_attn_size * frame_seqlen
)
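The attention window is chosen in frames and converted to tokens. The sketch below mirrors `CausalWanModel.local_attn_size` and `CausalWanSelfAttention.max_attention_size`; the settings (`max_chunk_size=4`, `num_frame_per_block=2`, `frame_seqlen=1560`) are hypothetical values chosen only for illustration.

```python
def local_attn_size(max_chunk_size: int, num_frame_per_block: int) -> int:
    # Mirrors CausalWanModel.local_attn_size; -1 means "no local window".
    return max_chunk_size * num_frame_per_block + 1 if max_chunk_size != -1 else -1


def max_attention_size(local_attn_frames: int, frame_seqlen: int) -> int:
    # Mirrors CausalWanSelfAttention.max_attention_size, measured in tokens.
    return 21 * frame_seqlen if local_attn_frames == -1 else local_attn_frames * frame_seqlen


FRAME_SEQLEN = 1560  # hypothetical tokens per latent frame
window = local_attn_size(max_chunk_size=4, num_frame_per_block=2)  # hypothetical settings
print(window, max_attention_size(window, FRAME_SEQLEN))  # 9 14040
print(max_attention_size(local_attn_size(-1, 2), FRAME_SEQLEN))  # 32760
```

Note that both formulas count video tokens only; action and state tokens per block do not appear in them.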

norm_k instance-attribute

norm_k = (
    DistributedRMSNorm(tp_inner_dim, eps=eps)
    if qk_norm
    else Identity()
)

norm_q instance-attribute

norm_q = (
    DistributedRMSNorm(tp_inner_dim, eps=eps)
    if qk_norm
    else Identity()
)

num_action_per_block instance-attribute

num_action_per_block = num_action_per_block

num_frame_per_block instance-attribute

num_frame_per_block = num_frame_per_block

num_heads instance-attribute

num_heads = num_heads

num_state_per_block instance-attribute

num_state_per_block = num_state_per_block

o instance-attribute

o = RowParallelLinear(
    dim,
    dim,
    bias=True,
    input_is_parallel=True,
    return_bias=False,
)

q instance-attribute

q = ColumnParallelLinear(
    dim,
    dim,
    bias=True,
    gather_output=False,
    return_bias=False,
)

tp_inner_dim instance-attribute

tp_inner_dim = tp_num_heads * head_dim

tp_num_heads instance-attribute

tp_num_heads = num_heads // tp_size
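Under tensor parallelism each rank keeps `num_heads // tp_size` heads, so the per-rank Q/K/V width is `tp_inner_dim`, which is also the width the `DistributedRMSNorm` instances are built with. The arithmetic below uses the 14B configuration; `tp_shard` is a hypothetical helper, and its divisibility check is an addition (the source formulas simply use floor division).

```python
def tp_shard(dim: int, num_heads: int, tp_size: int) -> dict[str, int]:
    """Per-rank attention widths, mirroring head_dim, tp_num_heads, tp_inner_dim."""
    if num_heads % tp_size != 0:
        # Added guard: with floor division a non-divisor would silently drop heads.
        raise ValueError(f"num_heads={num_heads} is not divisible by tp_size={tp_size}")
    head_dim = dim // num_heads
    tp_num_heads = num_heads // tp_size
    return {"head_dim": head_dim, "tp_num_heads": tp_num_heads, "tp_inner_dim": tp_num_heads * head_dim}


for tp_size in (1, 2, 4, 8):
    print(tp_size, tp_shard(dim=5120, num_heads=40, tp_size=tp_size))
```

With `tp_size=4`, each rank holds 10 heads of width 128, i.e. a 1280-wide slice of the 5120-wide projections.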

v instance-attribute

v = ColumnParallelLinear(
    dim,
    dim,
    bias=True,
    gather_output=False,
    return_bias=False,
)

forward

forward(
    x: Tensor,
    freqs: Tensor,
    freqs_action: Tensor,
    freqs_state: Tensor,
    action_register_length: int | None,
    kv_cache: Tensor | None = None,
    current_start_frame: int = 0,
    is_tf: bool = True,
) -> tuple[Tensor, Tensor | None]

Inference-only forward (KV cache path).
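This page does not describe how the KV cache is trimmed. As a rough sketch, assuming frame-level eviction that pins the first `sink_size` frames and keeps at most `local_attn_size` frames in total, a cache could behave like the hypothetical `RollingFrameCache` below. The real cache stores per-token keys and values (including action/state tokens), and the unbounded `local_attn_size == -1` case is not modelled here.

```python
class RollingFrameCache:
    """Hypothetical frame-level KV cache with pinned "sink" frames."""

    def __init__(self, local_attn_size: int, sink_size: int) -> None:
        if not 0 <= sink_size < local_attn_size:
            raise ValueError("need 0 <= sink_size < local_attn_size")
        self.capacity = local_attn_size
        self.sink_size = sink_size
        self.frames: list[object] = []  # one (k, v) entry per cached frame

    def append(self, kv: object) -> None:
        self.frames.append(kv)
        if len(self.frames) > self.capacity:
            # Evict the oldest non-sink frame; sink frames stay pinned.
            del self.frames[self.sink_size]

    def visible(self) -> list[object]:
        # Frames a newly generated block would attend to (plus itself).
        return list(self.frames)


cache = RollingFrameCache(local_attn_size=4, sink_size=1)
for frame_id in range(6):
    cache.append(frame_id)
print(cache.visible())  # [0, 3, 4, 5]
```

Keeping the earliest frames pinned while the rest of the window slides bounds memory without discarding the start of the stream entirely.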

DistributedRMSNorm

Bases: Module

RMSNorm that computes the global RMS across tensor-parallel ranks, i.e. over the full hidden dimension rather than over each rank's local shard.

eps instance-attribute

eps = eps

weight instance-attribute

weight = Parameter(ones(hidden_size))

forward

forward(x: Tensor) -> Tensor

weight_loader

weight_loader(param: Tensor, loaded_weight: Tensor) -> None
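After a ColumnParallelLinear, each rank holds only a slice of the hidden dimension, so a local RMSNorm would use the wrong statistics. The sketch below simulates ranks as list slices and assumes the global RMS is obtained from one all-reduce of the per-rank sums of squares (modelled as a plain `sum`); `rms_norm` and `distributed_rms_norm` are hypothetical helpers, not this module's code.

```python
import math


def rms_norm(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    # Reference RMSNorm over the full (unsharded) vector.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]


def distributed_rms_norm(
    shards: list[list[float]],         # one slice of the hidden dim per rank
    weight_shards: list[list[float]],  # matching slice of the weight per rank
    eps: float = 1e-6,
) -> list[list[float]]:
    local_sq = [sum(v * v for v in shard) for shard in shards]  # computed on each rank
    global_sq = sum(local_sq)  # stands in for an all-reduce(SUM) across ranks
    hidden = sum(len(s) for s in shards)
    rms = math.sqrt(global_sq / hidden + eps)
    # Each rank normalizes its own slice with the shared global RMS.
    return [[v / rms * w for v, w in zip(s, ws)] for s, ws in zip(shards, weight_shards)]
```

Concatenating the per-rank outputs reproduces the unsharded RMSNorm, whereas normalizing a shard on its own does not.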

MLPProj

Bases: Module

CLIP feature projection for i2v. Uses ColumnParallelLinear + RowParallelLinear (Qwen3_VisionMLP pattern).

act instance-attribute

act = GELU()

fc1 instance-attribute

fc1 = ColumnParallelLinear(
    in_dim, in_dim, bias=True, return_bias=False
)

fc2 instance-attribute

fc2 = RowParallelLinear(
    in_dim, out_dim, bias=True, return_bias=False
)

norm1 instance-attribute

norm1 = LayerNorm(in_dim)

norm2 instance-attribute

norm2 = LayerNorm(out_dim)

forward

forward(image_embeds: Tensor) -> Tensor

WanI2VCrossAttention

Bases: Module

Image-to-video cross-attention. Splits off the first 257 context tokens as image tokens, which get their own k_img/v_img projections. Uses vllm-omni Attention for the FlashAttn backend.

attn instance-attribute

attn = Attention(
    tp_num_heads,
    head_dim,
    causal=False,
    softmax_scale=head_dim**-0.5,
    skip_sequence_parallel=True,
)

dim instance-attribute

dim = dim

head_dim instance-attribute

head_dim = dim // num_heads

k instance-attribute

k = ColumnParallelLinear(
    dim,
    dim,
    bias=True,
    gather_output=False,
    return_bias=False,
)

k_img instance-attribute

k_img = ColumnParallelLinear(
    dim,
    dim,
    bias=True,
    gather_output=False,
    return_bias=False,
)

norm_k instance-attribute

norm_k = (
    DistributedRMSNorm(tp_inner_dim, eps=eps)
    if qk_norm
    else Identity()
)

norm_k_img instance-attribute

norm_k_img = (
    DistributedRMSNorm(tp_inner_dim, eps=eps)
    if qk_norm
    else Identity()
)

norm_q instance-attribute

norm_q = (
    DistributedRMSNorm(tp_inner_dim, eps=eps)
    if qk_norm
    else Identity()
)

num_heads instance-attribute

num_heads = num_heads

o instance-attribute

o = RowParallelLinear(
    dim,
    dim,
    bias=True,
    input_is_parallel=True,
    return_bias=False,
)

q instance-attribute

q = ColumnParallelLinear(
    dim,
    dim,
    bias=True,
    gather_output=False,
    return_bias=False,
)

tp_inner_dim instance-attribute

tp_inner_dim = tp_num_heads * head_dim

tp_num_heads instance-attribute

tp_num_heads = num_heads // tp_size

v instance-attribute

v = ColumnParallelLinear(
    dim,
    dim,
    bias=True,
    gather_output=False,
    return_bias=False,
)

v_img instance-attribute

v_img = ColumnParallelLinear(
    dim,
    dim,
    bias=True,
    gather_output=False,
    return_bias=False,
)

forward

forward(
    x: Tensor,
    context: Tensor,
    context_lens: Tensor | None = None,
    crossattn_cache: dict | None = None,
) -> Tensor
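A single-head, pure-Python sketch of the image/text split. It assumes, as in upstream Wan 2.1, that the query attends separately to the 257 image tokens and to the remaining text tokens and that the two results are summed before the output projection `o`. Projections, norms, head splitting and `crossattn_cache` are omitted; `attention` and `i2v_cross_attention` are hypothetical helpers.

```python
import math

NUM_IMAGE_TOKENS = 257  # from the class docstring


def attention(q: list[list[float]], k: list[list[float]], v: list[list[float]]) -> list[list[float]]:
    """Single-head softmax attention over lists of vectors (no masking)."""
    scale = 1 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) * scale for kj in k]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) / total for d in range(len(v[0]))])
    return out


def i2v_cross_attention(
    x: list[list[float]], context: list[list[float]], num_image_tokens: int = NUM_IMAGE_TOKENS
) -> list[list[float]]:
    context_img, context_txt = context[:num_image_tokens], context[num_image_tokens:]
    txt = attention(x, context_txt, context_txt)
    img = attention(x, context_img, context_img)
    # Sum the two attention results token by token.
    return [[a + b for a, b in zip(t, i)] for t, i in zip(txt, img)]
```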

WanLayerNorm

Bases: LayerNorm

LayerNorm wrapper used by DreamZero blocks.

WanT2VCrossAttention

Bases: Module

Text-to-video cross-attention. Uses vllm-omni Attention for the FlashAttn backend.

attn instance-attribute

attn = Attention(
    tp_num_heads,
    head_dim,
    causal=False,
    softmax_scale=head_dim**-0.5,
    skip_sequence_parallel=True,
)

dim instance-attribute

dim = dim

head_dim instance-attribute

head_dim = dim // num_heads

k instance-attribute

k = ColumnParallelLinear(
    dim,
    dim,
    bias=True,
    gather_output=False,
    return_bias=False,
)

norm_k instance-attribute

norm_k = (
    DistributedRMSNorm(tp_inner_dim, eps=eps)
    if qk_norm
    else Identity()
)

norm_q instance-attribute

norm_q = (
    DistributedRMSNorm(tp_inner_dim, eps=eps)
    if qk_norm
    else Identity()
)

num_heads instance-attribute

num_heads = num_heads

o instance-attribute

o = RowParallelLinear(
    dim,
    dim,
    bias=True,
    input_is_parallel=True,
    return_bias=False,
)

q instance-attribute

q = ColumnParallelLinear(
    dim,
    dim,
    bias=True,
    gather_output=False,
    return_bias=False,
)

tp_inner_dim instance-attribute

tp_inner_dim = tp_num_heads * head_dim

tp_num_heads instance-attribute

tp_num_heads = num_heads // tp_size

v instance-attribute

v = ColumnParallelLinear(
    dim,
    dim,
    bias=True,
    gather_output=False,
    return_bias=False,
)

forward

forward(
    x: Tensor,
    context: Tensor,
    context_lens: Tensor | None = None,
    crossattn_cache: dict | None = None,
) -> Tensor

causal_rope_action_apply

causal_rope_action_apply(
    x: Tensor,
    freqs: Tensor,
    freqs_action: Tensor,
    freqs_state: Tensor,
    action_register_length: int | None,
    num_action_per_block: int,
    num_state_per_block: int,
    action_state_index: int,
) -> Tensor

RoPE for a single inference step (causal/KV-cache mode).

rope_action_apply

rope_action_apply(
    x: Tensor,
    freqs: Tensor,
    freqs_action: Tensor,
    freqs_state: Tensor,
    action_register_length: int | None,
    num_action_per_block: int = 32,
    num_state_per_block: int = 1,
) -> Tensor

RoPE with action/state frequency tables for multi-step sequences.

rope_apply

rope_apply(x: Tensor, freqs: Tensor) -> Tensor

Apply RoPE to x using precomputed complex freqs.

rope_params

rope_params(max_seq_len: int, dim: int) -> Tensor

Precompute complex-valued RoPE frequencies (polar form).

Returns: complex tensor of shape [max_seq_len, dim // 2].
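A pure-Python sketch of the polar-form table and of how such a table rotates a vector. It assumes the common base `theta = 10000` and that consecutive pairs of channels are treated as one complex number; the real `rope_apply` additionally splits the head dimension into temporal/height/width segments, which this 1-D version does not. `rope_apply_1d` is a hypothetical helper.

```python
import cmath


def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> list[list[complex]]:
    """[max_seq_len][dim // 2] table of unit complex numbers e^(i * pos * inv_freq)."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    inv_freq = [1.0 / theta ** (i / dim) for i in range(0, dim, 2)]
    return [[cmath.rect(1.0, pos * f) for f in inv_freq] for pos in range(max_seq_len)]


def rope_apply_1d(x: list[float], freqs_row: list[complex]) -> list[float]:
    """Rotate each channel pair (x[2i], x[2i+1]) by the matching complex frequency."""
    out = []
    for i, f in enumerate(freqs_row):
        z = complex(x[2 * i], x[2 * i + 1]) * f
        out.extend([z.real, z.imag])
    return out
```

Because every entry is a unit complex number, the dot product of a rotated query at position m and a rotated key at position n depends only on m - n, which is what makes the encoding relative.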

sinusoidal_embedding_1d

sinusoidal_embedding_1d(
    dim: int, position: Tensor
) -> Tensor

Sinusoidal positional embedding for timesteps.
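A pure-Python sketch of a sinusoidal timestep embedding, assuming the upstream Wan convention of geometric frequencies `10000 ** (-i / half)` with all cosines first and all sines second. In this model the width would be `freq_dim`, the input width of `time_embedding`. This is an illustrative re-implementation, not the module's code.

```python
import math


def sinusoidal_embedding_1d(dim: int, positions: list[float]) -> list[list[float]]:
    """Return one [dim] embedding per position: cos terms, then sin terms."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    freqs = [10000.0 ** (-i / half) for i in range(half)]
    rows = []
    for p in positions:
        angles = [p * f for f in freqs]
        rows.append([math.cos(a) for a in angles] + [math.sin(a) for a in angles])
    return rows
```

At timestep 0 the embedding is all ones in the cosine half and all zeros in the sine half; larger timesteps rotate the low-index (high-frequency) channels fastest.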