vllm_omni.diffusion.models.lance.wan_vae

Wan2.2 VAE used by Lance, ported from upstream so Wan2.2_VAE.pth loads natively without state-dict surgery.

CACHE_T module-attribute

CACHE_T = 2

logger module-attribute

logger = init_logger(__name__)

AttentionBlock

Bases: Module

Single-head causal self-attention over spatial tokens, per frame.

dim instance-attribute

dim = dim

norm instance-attribute

norm = RMS_norm(dim)

proj instance-attribute

proj = Conv2d(dim, dim, 1)

to_qkv instance-attribute

to_qkv = Conv2d(dim, dim * 3, 1)

forward

forward(x)

AvgDown3D

Bases: Module

factor instance-attribute

factor = factor_t * factor_s * factor_s

factor_s instance-attribute

factor_s = factor_s

factor_t instance-attribute

factor_t = factor_t

group_size instance-attribute

group_size = in_channels * factor // out_channels

in_channels instance-attribute

in_channels = in_channels

out_channels instance-attribute

out_channels = out_channels

forward

forward(x: Tensor) -> Tensor
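
The channel bookkeeping for AvgDown3D follows directly from the attribute definitions above. The sketch below reproduces only that arithmetic in plain Python. The function name `avg_down_groups` and the example channel counts are illustrative assumptions, not values taken from the module.

```python
def avg_down_groups(in_channels: int, out_channels: int,
                    factor_t: int, factor_s: int) -> tuple[int, int]:
    """Return (factor, group_size) using the formulas listed for AvgDown3D."""
    # One temporal factor and two spatial factors (height and width).
    factor = factor_t * factor_s * factor_s
    # Integer division, exactly as in the attribute definition.
    group_size = in_channels * factor // out_channels
    return factor, group_size


# Hypothetical channel counts, chosen only to make the numbers easy to check.
factor, group_size = avg_down_groups(96, 192, factor_t=2, factor_s=2)
spatial_only = avg_down_groups(96, 192, factor_t=1, factor_s=2)
```

With these example numbers, `factor` is 8 and `group_size` is 4. Presumably `group_size` is the number of rearranged channels averaged into each output channel, in which case `in_channels * factor` would need to be divisible by `out_channels`; the page does not state this explicitly.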

CausalConv3d

Bases: Conv3d

Causal 3D conv with feature-map caching across temporal chunks.

padding instance-attribute

padding = (0, 0, 0)

forward

forward(x, cache_x=None)
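
To illustrate what "causal with caching across temporal chunks" means, here is a minimal one-dimensional sketch in plain Python. It is not the module's implementation: `causal_conv1d` is a hypothetical helper that works on a list of scalar "frames", and it assumes the cached frames simply replace the left-hand zero padding. A temporal kernel of size 3 needs two past frames, which is consistent with `CACHE_T = 2`.

```python
CACHE_T = 2  # same value as the module-level constant


def causal_conv1d(frames, kernel, cache=None):
    """Causal 1-D convolution along time: output[i] uses only inputs <= i.

    `cache` holds trailing frames from the previous chunk; they take the
    place of the zero padding that would otherwise sit on the left.
    """
    pad = len(kernel) - 1
    cached = list(cache)[-pad:] if (cache and pad) else []
    left = [0.0] * (pad - len(cached)) + cached
    padded = left + list(frames)
    return [
        sum(w * padded[i + j] for j, w in enumerate(kernel))
        for i in range(len(frames))
    ]


kernel = [0.25, 0.5, 1.0]
video = [1.0, 2.0, 3.0, 4.0, 5.0]

# Whole clip in one pass.
full = causal_conv1d(video, kernel)

# Same clip in two chunks, carrying over the last CACHE_T input frames.
first = causal_conv1d(video[:1], kernel)
second = causal_conv1d(video[1:], kernel, cache=video[:1][-CACHE_T:])
chunked = first + second
```

In this sketch the chunked result equals the single-pass result, which is the property that makes chunk-by-chunk encoding and decoding possible.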

Decoder3d

Bases: Module

attn_scales instance-attribute

attn_scales = list(attn_scales)

conv1 instance-attribute

conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

dim instance-attribute

dim = dim

dim_mult instance-attribute

dim_mult = list(dim_mult)

head instance-attribute

head = Sequential(
    RMS_norm(out_dim, images=False),
    SiLU(),
    CausalConv3d(out_dim, 12, 3, padding=1),
)

middle instance-attribute

middle = Sequential(
    ResidualBlock(dims[0], dims[0], dropout),
    AttentionBlock(dims[0]),
    ResidualBlock(dims[0], dims[0], dropout),
)

num_res_blocks instance-attribute

num_res_blocks = num_res_blocks

temporal_upsample instance-attribute

temporal_upsample = list(temporal_upsample)

upsamples instance-attribute

upsamples = Sequential(*upsamples)

z_dim instance-attribute

z_dim = z_dim

forward

forward(
    x, feat_cache=None, feat_idx=[0], first_chunk=False
)

Down_ResidualBlock

Bases: Module

avg_shortcut instance-attribute

avg_shortcut = AvgDown3D(
    in_dim,
    out_dim,
    factor_t=2 if temperal_downsample else 1,
    factor_s=2 if down_flag else 1,
)

downsamples instance-attribute

downsamples = Sequential(*downsamples)

forward

forward(x, feat_cache=None, feat_idx=[0])

DupUp3D

Bases: Module

factor instance-attribute

factor = factor_t * factor_s * factor_s

factor_s instance-attribute

factor_s = factor_s

factor_t instance-attribute

factor_t = factor_t

in_channels instance-attribute

in_channels = in_channels

out_channels instance-attribute

out_channels = out_channels

repeats instance-attribute

repeats = out_channels * factor // in_channels

forward

forward(x: Tensor, first_chunk=False) -> Tensor

Encoder3d

Bases: Module

attn_scales instance-attribute

attn_scales = list(attn_scales)

conv1 instance-attribute

conv1 = CausalConv3d(12, dims[0], 3, padding=1)

dim instance-attribute

dim = dim

dim_mult instance-attribute

dim_mult = list(dim_mult)

downsamples instance-attribute

downsamples = Sequential(*downsamples)

head instance-attribute

head = Sequential(
    RMS_norm(out_dim, images=False),
    SiLU(),
    CausalConv3d(out_dim, z_dim, 3, padding=1),
)

middle instance-attribute

middle = Sequential(
    ResidualBlock(out_dim, out_dim, dropout),
    AttentionBlock(out_dim),
    ResidualBlock(out_dim, out_dim, dropout),
)

num_res_blocks instance-attribute

num_res_blocks = num_res_blocks

temperal_downsample instance-attribute

temperal_downsample = list(temperal_downsample)

z_dim instance-attribute

z_dim = z_dim

forward

forward(x, feat_cache=None, feat_idx=[0])

LanceWanVAE

Bases: Module

Wan2.2 VAE wrapped for BAGEL's pipeline.

Exposes BAGEL's image-VAE surface, encode(BCHW) -> BC_zHW and decode(BC_zHW) -> BCHW, by treating each image as a 1-frame video clip. A 5-D encode_video/decode_video path is also provided for the Lance_3B_Video checkpoint.

Construction is lazy: the heavy WanVAE_ module is not built, and Wan2.2_VAE.pth is not loaded, until first use. Once built, the inner module is registered as a submodule so that self.parameters(), self.to(device) and vae_dtype = next(vae.parameters()).dtype (used by BAGEL's decode path) all behave as expected.
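
The lazy-construction idea can be sketched independently of PyTorch. The classes below are hypothetical stand-ins (`LazyAdapter` is not LanceWanVAE, and `FakeHeavyModel` is a placeholder for an expensive model). They only show the pattern of deferring the build until the first call and reusing the result afterwards.

```python
class FakeHeavyModel:
    """Placeholder for an expensive model; counts how often it is built."""

    builds = 0

    def __init__(self):
        FakeHeavyModel.builds += 1

    def encode(self, x):
        return [v * 2 for v in x]


class LazyAdapter:
    """Hypothetical wrapper that builds its inner model on first use."""

    def __init__(self, builder):
        self._builder = builder  # callable that creates the heavy object
        self._inner = None       # nothing is built at construction time

    def _ensure_built(self):
        if self._inner is None:
            self._inner = self._builder()
        return self._inner

    def encode(self, x):
        return self._ensure_built().encode(x)


adapter = LazyAdapter(FakeHeavyModel)
built_before = FakeHeavyModel.builds  # construction has been deferred
result = adapter.encode([1, 2])       # first call triggers the build
adapter.encode([3])                   # later calls reuse the same object
built_after = FakeHeavyModel.builds
```

In PyTorch, assigning an `nn.Module` to an attribute of another `nn.Module` registers it as a submodule, which is presumably how the inner VAE becomes visible to `parameters()` and `to(device)` once it has been built.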

downsample_spatial class-attribute instance-attribute

downsample_spatial: int = 16

downsample_temporal class-attribute instance-attribute

downsample_temporal: int = 4

z_channels class-attribute instance-attribute

z_channels: int = 48

decode

decode(latent: Tensor) -> Tensor

[B, 48, h, w] -> [B, 3, H, W] (single-frame image path).

decode_video

decode_video(latent: Tensor) -> Tensor

Decode a 5-D latent [B, 48, t, h, w] -> video [B, 3, T, H, W].

encode

encode(padded_images: Tensor) -> Tensor

[B, 3, H, W] -> [B, 48, H/16, W/16] (single-frame image path).

Each image is wrapped as a 1-frame clip, encoded, and the temporal axis is squeezed back out so the result matches BAGEL's iteration pattern.

encode_video

encode_video(
    video: Tensor, *, use_sample: bool = True
) -> Tensor

Encode a 5-D clip [B, 3, T, H, W] -> latent [B, 48, t, h, w].
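
The shape contracts for encode and encode_video can be checked with a small helper. The spatial part follows the documented `[B, 3, H, W] -> [B, 48, H/16, W/16]` mapping. The temporal formula `t = 1 + (T - 1) // 4` is an assumption based on the usual causal-VAE convention (first frame encoded alone), chosen so that a 1-frame clip gives `t = 1`; the page itself only states `downsample_temporal = 4`. Both function names are hypothetical.

```python
Z_CHANNELS = 48          # LanceWanVAE.z_channels
DOWNSAMPLE_SPATIAL = 16  # LanceWanVAE.downsample_spatial
DOWNSAMPLE_TEMPORAL = 4  # LanceWanVAE.downsample_temporal


def image_latent_shape(b: int, h: int, w: int) -> tuple[int, int, int, int]:
    """[B, 3, H, W] -> [B, 48, H/16, W/16], as documented for encode()."""
    # ASSUMPTION of this sketch: inputs are already padded to multiples of 16.
    if h % DOWNSAMPLE_SPATIAL or w % DOWNSAMPLE_SPATIAL:
        raise ValueError("H and W are expected to be multiples of 16")
    return (b, Z_CHANNELS, h // DOWNSAMPLE_SPATIAL, w // DOWNSAMPLE_SPATIAL)


def video_latent_shape(b: int, t: int, h: int, w: int) -> tuple[int, ...]:
    """[B, 3, T, H, W] -> [B, 48, t, h, w].

    ASSUMPTION: t = 1 + (T - 1) // 4; the page does not give this formula.
    """
    _, c, lh, lw = image_latent_shape(b, h, w)
    lt = 1 + (t - 1) // DOWNSAMPLE_TEMPORAL
    return (b, c, lt, lh, lw)


image_shape = image_latent_shape(2, 512, 768)
clip_shape = video_latent_shape(1, 17, 256, 256)
single_frame = video_latent_shape(1, 1, 256, 256)
```

Under these assumptions a 512x768 image maps to a 32x48 latent grid, and a single-frame clip keeps a temporal length of 1, matching the image path that squeezes the temporal axis back out.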

RMS_norm

Bases: Module

bias instance-attribute

bias = Parameter(zeros(shape)) if bias else 0.0

channel_first instance-attribute

channel_first = channel_first

gamma instance-attribute

gamma = Parameter(ones(shape))

scale instance-attribute

scale = dim ** 0.5

forward

forward(x)
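
The `scale = dim ** 0.5` attribute suggests the usual trick of L2-normalizing along the channel axis and multiplying by the square root of the channel count, which is equivalent to dividing by the root mean square. The forward pass is not shown on this page, so the sketch below is an assumption about the computation, written for a single channel vector in plain Python. `rms_norm` is a hypothetical helper, not the module's method.

```python
import math


def rms_norm(x, gamma=None, bias=0.0, eps=1e-12):
    """L2-normalize x, then rescale by sqrt(dim), gamma and bias."""
    dim = len(x)
    scale = dim ** 0.5  # same expression as the `scale` attribute
    l2 = max(math.sqrt(sum(v * v for v in x)), eps)
    gamma = gamma if gamma is not None else [1.0] * dim
    return [v / l2 * scale * g + bias for v, g in zip(x, gamma)]


x = [3.0, 4.0]
out = rms_norm(x)

# Reference: divide each entry by the root mean square of the vector.
rms = math.sqrt(sum(v * v for v in x) / len(x))
reference = [v / rms for v in x]
```

With unit `gamma` and zero `bias`, `out` and `reference` agree to floating-point precision, which shows why the scale factor is the square root of the dimension.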

Resample

Bases: Module

dim instance-attribute

dim = dim

mode instance-attribute

mode = mode

resample instance-attribute

resample = Sequential(
    Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
    Conv2d(dim, dim, 3, padding=1),
)

time_conv instance-attribute

time_conv = CausalConv3d(
    dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)
)

forward

forward(x, feat_cache=None, feat_idx=[0])

ResidualBlock

Bases: Module

in_dim instance-attribute

in_dim = in_dim

out_dim instance-attribute

out_dim = out_dim

residual instance-attribute

residual = Sequential(
    RMS_norm(in_dim, images=False),
    SiLU(),
    CausalConv3d(in_dim, out_dim, 3, padding=1),
    RMS_norm(out_dim, images=False),
    SiLU(),
    Dropout(dropout),
    CausalConv3d(out_dim, out_dim, 3, padding=1),
)

shortcut instance-attribute

shortcut = (
    CausalConv3d(in_dim, out_dim, 1)
    if in_dim != out_dim
    else Identity()
)

forward

forward(x, feat_cache=None, feat_idx=[0])

Up_ResidualBlock

Bases: Module

avg_shortcut instance-attribute

avg_shortcut = DupUp3D(
    in_dim,
    out_dim,
    factor_t=2 if temporal_upsample else 1,
    factor_s=2 if up_flag else 1,
)

upsamples instance-attribute

upsamples = Sequential(*upsamples)

forward

forward(
    x, feat_cache=None, feat_idx=[0], first_chunk=False
)

Upsample

Bases: Upsample

forward

forward(x)

WanVAE_

Bases: Module

Upstream Wan2.2 VAE module: an Encoder3d/Decoder3d pair with 2x patchify applied to the input. State-dict-compatible with Wan2.2_VAE.pth.
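
The 2x patchify explains the channel counts elsewhere on this page: Encoder3d.conv1 takes 12 input channels and Decoder3d.head emits 12 output channels, which matches 3 RGB channels times a 2x2 spatial patch. A minimal shape-only sketch follows. The helper names are hypothetical, and the ordering of values inside the folded channel axis is deliberately not modeled.

```python
def patchify_shape(shape, patch_size=2):
    """Shape after folding each patch_size x patch_size block into channels."""
    b, c, t, h, w = shape
    if h % patch_size or w % patch_size:
        raise ValueError("H and W must be divisible by patch_size")
    return (b, c * patch_size * patch_size, t, h // patch_size, w // patch_size)


def unpatchify_shape(shape, patch_size=2):
    """Inverse of patchify_shape: unfold channels back into spatial blocks."""
    b, c, t, h, w = shape
    if c % (patch_size * patch_size):
        raise ValueError("channel count must be divisible by patch_size ** 2")
    return (b, c // (patch_size * patch_size), t, h * patch_size, w * patch_size)


video_shape = (1, 3, 5, 64, 64)            # [B, 3, T, H, W]
patched = patchify_shape(video_shape)      # 12 channels, half the resolution
restored = unpatchify_shape(patched)
```

Here `patched` has 12 channels at half the spatial resolution, and `restored` recovers the original shape.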

attn_scales instance-attribute

attn_scales = list(attn_scales)

conv1 instance-attribute

conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)

conv2 instance-attribute

conv2 = CausalConv3d(z_dim, z_dim, 1)

decoder instance-attribute

decoder = Decoder3d(
    dec_dim,
    z_dim,
    dim_mult,
    num_res_blocks,
    attn_scales,
    temporal_upsample,
    dropout,
)

dim instance-attribute

dim = dim

dim_mult instance-attribute

dim_mult = list(dim_mult)

encoder instance-attribute

encoder = Encoder3d(
    dim,
    z_dim * 2,
    dim_mult,
    num_res_blocks,
    attn_scales,
    temperal_downsample,
    dropout,
)

num_res_blocks instance-attribute

num_res_blocks = num_res_blocks

temperal_downsample instance-attribute

temperal_downsample = list(temperal_downsample)

temporal_upsample instance-attribute

temporal_upsample = list(temperal_downsample)[::-1]

z_dim instance-attribute

z_dim = z_dim

clear_cache

clear_cache()

decode

decode(z, scale)

encode

encode(x, scale)

build_wan22_vae

build_wan22_vae(
    vae_path: str, dtype: dtype = bfloat16, device=None
) -> LanceWanVAE

Convenience factory: lazily constructs a LanceWanVAE adapter.