vllm_omni.diffusion.models.dreamzero.causal_wan_model ¶
CausalWanModel — 40-layer DiT with causal attention and KV cache.
Key differences from WanTransformer3DModel:
- Causal self-attention (new frames only see history)
- KV cache for streaming inference
- Action/state token support (appended after video tokens)
- Extended RoPE with action/state-specific frequencies
- Inference-only forward with KV cache
WAN_CROSSATTENTION_CLASSES module-attribute ¶
WAN_CROSSATTENTION_CLASSES = {
"t2v_cross_attn": WanT2VCrossAttention,
"i2v_cross_attn": WanI2VCrossAttention,
}
CausalHead ¶
Bases: Module
Output norm + linear with 2-param modulation. Runs once per step (not TP-critical), uses nn.Linear.
forward ¶
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | [B, L1, C] | required |
| e | Tensor | [B, F, 1, C] (time embedding, unsqueezed) | required |
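The "2-param modulation" mentioned for CausalHead can be illustrated with a small sketch. It assumes the two parameters are a shift and a scale applied to the normalized input before the final linear layer, as is common in adaLN-style DiT heads. The function name `modulate` and the sample values are hypothetical and not taken from the module.

```python
def modulate(x_norm, shift, scale):
    """Apply shift/scale modulation to one normalized feature vector.

    Assumption: the head computes x_norm * (1 + scale) + shift per channel.
    """
    return [v * (1.0 + s) + b for v, s, b in zip(x_norm, scale, shift)]


# Hypothetical values for a 3-channel feature vector.
x_norm = [0.5, -1.0, 2.0]
shift = [0.1, 0.0, -0.2]
scale = [0.0, 1.0, 0.5]
out = modulate(x_norm, shift, scale)
```

With `scale` and `shift` both zero, the modulation reduces to the identity, so the head then behaves like a plain norm followed by a linear layer.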
CausalWanAttentionBlock ¶
Bases: Module
Transformer block: self-attn + cross-attn + FFN with 6-param modulation.
cross_attn instance-attribute ¶
cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](
dim, num_heads, (-1, -1), qk_norm, eps
)
ffn instance-attribute ¶
ffn = Sequential(
ColumnParallelLinear(
dim,
ffn_dim,
bias=True,
gather_output=False,
return_bias=False,
),
GELU(approximate="tanh"),
RowParallelLinear(
ffn_dim,
dim,
bias=True,
input_is_parallel=True,
return_bias=False,
),
)
norm3 instance-attribute ¶
norm3 = (
WanLayerNorm(dim, eps, elementwise_affine=True)
if cross_attn_norm
else Identity()
)
self_attn instance-attribute ¶
self_attn = CausalWanSelfAttention(
dim=dim,
num_heads=num_heads,
frame_seqlen=frame_seqlen,
local_attn_size=local_attn_size,
sink_size=sink_size,
num_frame_per_block=num_frame_per_block,
qk_norm=qk_norm,
eps=eps,
num_action_per_block=num_action_per_block,
num_state_per_block=num_state_per_block,
)
forward ¶
forward(
x: Tensor,
e: Tensor,
freqs: Tensor,
freqs_action: Tensor,
freqs_state: Tensor,
context: Tensor,
action_register_length: int | None = None,
kv_cache: Tensor | None = None,
crossattn_cache: dict | None = None,
current_start_frame: int = 0,
is_tf: bool = True,
) -> tuple[Tensor, Tensor | None]
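As a rough illustration of the "6-param modulation" used by this block, the sketch below assumes the six values split into a (shift, scale, gate) triple for self-attention and another for the FFN, with cross-attention left unmodulated. That ordering follows the usual Wan/DiT pattern and is an assumption here. Normalization layers and the KV cache are omitted, values are scalars for readability, and all names are hypothetical.

```python
def gated_residual(x, sublayer, shift, scale, gate):
    """Return x + gate * sublayer(x * (1 + scale) + shift)."""
    return x + gate * sublayer(x * (1.0 + scale) + shift)


def block_forward_sketch(x, e, self_attn, cross_attn, ffn):
    # Assumed ordering of the six modulation values:
    # (shift, scale, gate) for self-attention, then (shift, scale, gate) for the FFN.
    sa_shift, sa_scale, sa_gate, ffn_shift, ffn_scale, ffn_gate = e
    x = gated_residual(x, self_attn, sa_shift, sa_scale, sa_gate)
    x = x + cross_attn(x)  # plain residual; not modulated in this sketch
    x = gated_residual(x, ffn, ffn_shift, ffn_scale, ffn_gate)
    return x


def identity(v):
    return v


def zero(v):
    return 0.0


# Both gates open, no shift or scale: each gated sublayer doubles x.
result = block_forward_sketch(1.0, (0.0, 0.0, 1.0, 0.0, 0.0, 1.0), identity, zero, identity)
```

Setting both gates to zero makes the block pass its input through unchanged, which is one reason gated residuals are a popular choice for conditioning on the time embedding.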
CausalWanModel ¶
Bases: Module
Causal video diffusion transformer for DreamZero.
Architecture (14B): 40 layers, dim=5120, heads=40, ffn=13824
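The listed hyperparameters can be sanity-checked with a little arithmetic. The sketch below derives the per-head dimension and a rough weight count. It counts only the four square projections in self-attention, four in cross-attention, and the two FFN matrices per layer. Biases, norms, embeddings, the extra image projections of the i2v variant, and the action/state modules are ignored, so the total is an estimate rather than the exact parameter count.

```python
num_layers = 40
dim = 5120
num_heads = 40
ffn_dim = 13824

# Per-head channel width.
head_dim = dim // num_heads

# Rough weight count per layer (estimate only, see assumptions above).
attn_weights = 4 * dim * dim          # q, k, v, o projections
per_layer = 2 * attn_weights + 2 * dim * ffn_dim  # self-attn + cross-attn + FFN
total = num_layers * per_layer

print(head_dim, per_layer, total)
```

Under these assumptions the transformer blocks alone account for roughly 14.05 billion weights, which is consistent with the "14B" label.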
action_decoder instance-attribute ¶
action_decoder = CategorySpecificMLP(
num_categories=max_num_embodiments_local,
input_dim=dim,
hidden_dim=hidden_size,
output_dim=action_dim,
)
action_encoder instance-attribute ¶
action_encoder = MultiEmbodimentActionEncoder(
action_dim=action_dim,
hidden_size=dim,
num_embodiments=max_num_embodiments_local,
)
blocks instance-attribute ¶
blocks = ModuleList(
[
(
CausalWanAttentionBlock(
cross_attn_type,
dim,
ffn_dim,
num_heads,
frame_seqlen,
local_attn_size,
sink_size,
num_frame_per_block,
qk_norm,
cross_attn_norm,
eps,
num_action_per_block,
num_state_per_block,
)
)
for _ in (range(num_layers))
]
)
freqs instance-attribute ¶
freqs = [
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
]
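To make the three `rope_params` calls concrete, the following sketch evaluates the dimension split for one head. It assumes `d` is the per-head dimension (`dim // num_heads`, so 128 for the 14B configuration) and that the three tables correspond to the temporal, height, and width axes. Both points are assumptions based on the usual Wan layout and are not stated on this page.

```python
def rope_axis_dims(d):
    """Return the rotary dims passed to each of the three rope_params calls."""
    return [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]


dims = rope_axis_dims(128)           # assumed head dim for the 14B model
complex_widths = [n // 2 for n in dims]  # rope_params returns dim // 2 columns
```

The three parts always sum back to `d`, so every channel of a head receives a rotary phase from exactly one table.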
local_attn_size instance-attribute ¶
patch_embedding instance-attribute ¶
state_encoder instance-attribute ¶
state_encoder = CategorySpecificMLP(
num_categories=max_num_embodiments_local,
input_dim=max_state_dim,
hidden_dim=hidden_size,
output_dim=dim,
)
text_embedding instance-attribute ¶
time_embedding instance-attribute ¶
unpatchify ¶
Reconstruct video from patch embeddings.
CausalWanSelfAttention ¶
Bases: Module
Causal self-attention with KV cache + action/state tokens.
attn instance-attribute ¶
attn = Attention(
tp_num_heads,
head_dim,
causal=False,
softmax_scale=head_dim**-0.5,
skip_sequence_parallel=True,
)
k instance-attribute ¶
max_attention_size instance-attribute ¶
max_attention_size = (
21 * frame_seqlen
if local_attn_size == -1
else local_attn_size * frame_seqlen
)
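The expression above can be read as a token budget for the attention window. The sketch below restates it as a standalone function. The interpretation of `-1` as "no local window, fall back to 21 frames" and the sample `frame_seqlen` value are assumptions for illustration.

```python
def max_attention_size(frame_seqlen, local_attn_size=-1):
    """Maximum number of tokens a query may attend to.

    local_attn_size == -1 is treated as the default case covering 21 frames.
    """
    if local_attn_size == -1:
        return 21 * frame_seqlen
    return local_attn_size * frame_seqlen


# Hypothetical frame_seqlen of 880 tokens per latent frame.
full_window = max_attention_size(880)
local_window = max_attention_size(880, local_attn_size=6)
```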
norm_k instance-attribute ¶
norm_k = (
DistributedRMSNorm(tp_inner_dim, eps=eps)
if qk_norm
else Identity()
)
norm_q instance-attribute ¶
norm_q = (
DistributedRMSNorm(tp_inner_dim, eps=eps)
if qk_norm
else Identity()
)
o instance-attribute ¶
q instance-attribute ¶
v instance-attribute ¶
DistributedRMSNorm ¶
MLPProj ¶
Bases: Module
CLIP feature projection for i2v. Uses ColumnParallelLinear + RowParallelLinear (Qwen3_VisionMLP pattern).
WanI2VCrossAttention ¶
Bases: Module
Image-to-video cross-attention (splits first 257 image tokens). Uses vllm-omni Attention for FlashAttn backend.
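The token split mentioned above can be sketched as follows. The example assumes the context sequence places the 257 image tokens first and the text tokens after them, and it uses a plain list in place of a tensor. The helper name `split_image_context` is hypothetical.

```python
NUM_IMAGE_TOKENS = 257  # count taken from the class description


def split_image_context(context, num_image_tokens=NUM_IMAGE_TOKENS):
    """Split a context sequence into (image_tokens, text_tokens)."""
    if len(context) < num_image_tokens:
        raise ValueError("context is shorter than the image token prefix")
    return context[:num_image_tokens], context[num_image_tokens:]


# Stand-in context: 257 image tokens followed by 43 text tokens.
context = list(range(300))
image_tokens, text_tokens = split_image_context(context)
```

On tensors shaped `[B, L, C]` the same split would typically be a slice along the sequence dimension, with the image part feeding `k_img`/`v_img` and the text part feeding `k`/`v`.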
attn instance-attribute ¶
attn = Attention(
tp_num_heads,
head_dim,
causal=False,
softmax_scale=head_dim**-0.5,
skip_sequence_parallel=True,
)
k instance-attribute ¶
k_img instance-attribute ¶
norm_k instance-attribute ¶
norm_k = (
DistributedRMSNorm(tp_inner_dim, eps=eps)
if qk_norm
else Identity()
)
norm_k_img instance-attribute ¶
norm_k_img = (
DistributedRMSNorm(tp_inner_dim, eps=eps)
if qk_norm
else Identity()
)
norm_q instance-attribute ¶
norm_q = (
DistributedRMSNorm(tp_inner_dim, eps=eps)
if qk_norm
else Identity()
)
o instance-attribute ¶
q instance-attribute ¶
v instance-attribute ¶
v_img instance-attribute ¶
WanLayerNorm ¶
Bases: LayerNorm
LayerNorm wrapper used by DreamZero blocks.
WanT2VCrossAttention ¶
Bases: Module
Text-to-video cross-attention. Uses vllm-omni Attention for FlashAttn backend.
attn instance-attribute ¶
attn = Attention(
tp_num_heads,
head_dim,
causal=False,
softmax_scale=head_dim**-0.5,
skip_sequence_parallel=True,
)
k instance-attribute ¶
norm_k instance-attribute ¶
norm_k = (
DistributedRMSNorm(tp_inner_dim, eps=eps)
if qk_norm
else Identity()
)
norm_q instance-attribute ¶
norm_q = (
DistributedRMSNorm(tp_inner_dim, eps=eps)
if qk_norm
else Identity()
)
o instance-attribute ¶
q instance-attribute ¶
v instance-attribute ¶
causal_rope_action_apply ¶
causal_rope_action_apply(
x: Tensor,
freqs: Tensor,
freqs_action: Tensor,
freqs_state: Tensor,
action_register_length: int | None,
num_action_per_block: int,
num_state_per_block: int,
action_state_index: int,
) -> Tensor
RoPE for single inference step (causal / KV-cache mode).
rope_action_apply ¶
rope_action_apply(
x: Tensor,
freqs: Tensor,
freqs_action: Tensor,
freqs_state: Tensor,
action_register_length: int | None,
num_action_per_block: int = 32,
num_state_per_block: int = 1,
) -> Tensor
RoPE with action/state frequency tables for multi-step sequences.
rope_apply ¶
Apply RoPE to x using precomputed complex freqs.
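Applying RoPE with complex frequencies amounts to a complex multiplication per channel pair. The sketch below shows this for a single token vector. It assumes consecutive channels `(x[2i], x[2i+1])` form the real and imaginary parts of each pair, which is a common convention but not confirmed by this page. The function name is hypothetical.

```python
import cmath
import math


def rope_apply_sketch(x, freqs):
    """Rotate consecutive channel pairs of x by unit-modulus complex freqs."""
    if len(x) != 2 * len(freqs):
        raise ValueError("x must have two channels per frequency")
    out = []
    for i, f in enumerate(freqs):
        z = complex(x[2 * i], x[2 * i + 1]) * f
        out.extend([z.real, z.imag])
    return out


# Rotate the pair (1, 0) by 90 degrees and leave the pair (3, 4) unrotated.
rotated = rope_apply_sketch([1.0, 0.0, 3.0, 4.0], [cmath.rect(1.0, math.pi / 2), 1 + 0j])
```

Because every frequency has modulus one, the rotation preserves the norm of each pair, so attention logits depend on position only through relative phase.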
rope_params ¶
Precompute complex-valued RoPE frequencies (polar form).
Returns: complex tensor [max_seq_len, dim // 2]
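A minimal pure-Python sketch of such a table is shown below. It assumes the standard RoPE schedule with base `theta = 10000` and inverse frequencies `theta ** (-i / dim)` for even `i`. The base value and the function name are assumptions, and nested lists stand in for the complex tensor.

```python
import cmath


def rope_params_sketch(max_seq_len, dim, theta=10000.0):
    """Return a [max_seq_len][dim // 2] table of unit-modulus complex numbers."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    # One inverse frequency per channel pair.
    inv_freq = [1.0 / theta ** (i / dim) for i in range(0, dim, 2)]
    # Polar form: modulus 1, angle = position * inverse frequency.
    return [
        [cmath.rect(1.0, pos * f) for f in inv_freq]
        for pos in range(max_seq_len)
    ]


table = rope_params_sketch(max_seq_len=4, dim=8)
```

Row 0 is all ones because position zero has zero phase, and later rows rotate fastest in the first column and slowest in the last.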