vllm_omni.diffusion.models.wan2_2

Modules:

Name Description
patch_diffusers
pipeline_wan2_2
pipeline_wan2_2_i2v
pipeline_wan2_2_s2v

Wan2.2 Speech-to-Video (S2V) Pipeline for vLLM-Omni.

pipeline_wan2_2_vace

VACE (Video Creation and Editing) Pipeline for WAN models.

scheduling_wan_euler
wan2_2_s2v_transformer

Wan2.2 Speech-to-Video (S2V) Transformer using vllm-omni ops.

wan2_2_transformer
wan2_2_vace_transformer

VACE variant of WanTransformer3DModel for conditional video generation.

VaceWanTransformerBlock

Bases: WanTransformerBlock

VACE variant of WanTransformerBlock with proj_in/proj_out for skip connections.

proj_in instance-attribute

proj_in = Linear(dim, dim) if block_id == 0 else None

proj_out instance-attribute

proj_out = Linear(dim, dim)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor,
    control_hidden_states: Tensor,
    temb: Tensor,
    rotary_emb: tuple[Tensor, Tensor],
    hidden_states_mask: Tensor | None = None,
) -> tuple[Tensor, Tensor]

Wan22I2VPipeline

Bases: Module, SupportImageInput, PipelineParallelMixin, CFGParallelMixin, ProgressBarMixin, DiffusionPipelineProfilerMixin

Wan2.2 Image-to-Video Pipeline.

Supports both Wan2.1-style I2V (with CLIP image embeddings) and Wan2.2-style I2V (with expand_timesteps for TI2V-5B).

boundary_ratio instance-attribute

boundary_ratio = boundary_ratio

current_timestep property

current_timestep

device instance-attribute

device = get_local_device()

do_classifier_free_guidance property

do_classifier_free_guidance

expand_timesteps instance-attribute

expand_timesteps = getattr(
    od_config, "expand_timesteps", False
)

guidance_scale property

guidance_scale

has_image_encoder instance-attribute

has_image_encoder = (
    "image_encoder" in model_index
    and model_index["image_encoder"][0] is not None
)

has_transformer_2 instance-attribute

has_transformer_2 = 'transformer_2' in model_index

image_encoder instance-attribute

image_encoder = to(device)

image_processor instance-attribute

image_processor = from_pretrained_with_prefetch(
    from_pretrained,
    model,
    subfolder="image_processor",
    prefetch_list=subfolders,
    local_files_only=local_files_only,
)

num_timesteps property

num_timesteps

od_config instance-attribute

od_config = od_config

scheduler instance-attribute

scheduler = build_wan_scheduler(_sample_solver, _flow_shift)

text_encoder instance-attribute

text_encoder = to(device)

tokenizer instance-attribute

tokenizer = from_pretrained_with_prefetch(
    from_pretrained,
    model,
    subfolder="tokenizer",
    prefetch_list=subfolders,
    local_files_only=local_files_only,
)

transformer instance-attribute

transformer = create_transformer_from_config(
    transformer_config, quant_config=quantization_config
)

transformer_2 instance-attribute

transformer_2 = create_transformer_from_config(
    transformer_2_config, quant_config=t2_quant
)

vae instance-attribute

vae = to(device)

vae_scale_factor_spatial instance-attribute

vae_scale_factor_spatial = (
    scale_factor_spatial if hasattr(vae, "config") else 8
)

vae_scale_factor_temporal instance-attribute

vae_scale_factor_temporal = (
    scale_factor_temporal if hasattr(vae, "config") else 4
)

weights_sources instance-attribute

weights_sources = [
    ComponentSource(
        model_or_path=model,
        subfolder="transformer",
        revision=None,
        prefix="transformer.",
        fall_back_to_pt=True,
    )
]

check_inputs

check_inputs(
    prompt,
    negative_prompt,
    image,
    height,
    width,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    image_embeds=None,
    guidance_scale_2=None,
    boundary_ratio=None,
)

diffuse

diffuse(
    latents: Tensor,
    timesteps: Tensor,
    prompt_embeds: Tensor,
    negative_prompt_embeds: Tensor | None,
    image_embeds: Tensor | None,
    guidance_low: float,
    guidance_high: float,
    boundary_timestep: float | None,
    dtype: dtype,
    attention_kwargs: dict[str, Any],
    condition: Tensor,
    first_frame_mask: Tensor,
) -> Tensor | AsyncLatents

encode_image

encode_image(
    image: Image | list[Image], device: device | None = None
) -> Tensor

Encode image using CLIP image encoder.

encode_prompt

encode_prompt(
    prompt: str | list[str],
    negative_prompt: str | list[str] | None = None,
    do_classifier_free_guidance: bool = True,
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 512,
    device: device | None = None,
    dtype: dtype | None = None,
)

Encode text prompts using T5 text encoder.

forward

forward(
    req: OmniDiffusionRequest,
    prompt: str | None = None,
    negative_prompt: str | None = None,
    image: Image | Tensor | None = None,
    height: int = 480,
    width: int = 832,
    num_inference_steps: int = 40,
    guidance_scale: float | tuple[float, float] = 5.0,
    frame_num: int = 81,
    output_type: str | None = "np",
    generator: Generator | list[Generator] | None = None,
    prompt_embeds: Tensor | None = None,
    negative_prompt_embeds: Tensor | None = None,
    image_embeds: Tensor | None = None,
    last_image: Image | Tensor | None = None,
    attention_kwargs: dict | None = None,
    **kwargs,
) -> DiffusionOutput

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights using AutoWeightsLoader for vLLM integration.

predict_noise

predict_noise(
    current_model: Module | None = None, **kwargs: Any
) -> Tensor | IntermediateTensors

Forward pass through transformer to predict noise.

Parameters:

Name Type Description Default
current_model Module | None

The transformer model to use (transformer or transformer_2)

None
**kwargs Any

Arguments to pass to the transformer

{}

Returns:

Type Description
Tensor | IntermediateTensors

Predicted noise tensor or IntermediateTensors on non-last PP stages.
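
The choice between transformer and transformer_2 can be illustrated with a minimal sketch. It assumes the convention used by the diffusers Wan pipelines: boundary_timestep is boundary_ratio times the number of training timesteps (taken as 1000 here), steps at or above the boundary run transformer, and later steps run transformer_2. The helper names select_expert and boundary_from_ratio are hypothetical, and the pairing of guidance_high with the high-noise steps is an assumption rather than something this page states.

```python
def boundary_from_ratio(boundary_ratio, num_train_timesteps=1000):
    """Convert a ratio in [0, 1] to a timestep threshold (None disables switching)."""
    if boundary_ratio is None:
        return None
    return boundary_ratio * num_train_timesteps


def select_expert(t, boundary_timestep, guidance_low, guidance_high):
    """Return (model_name, guidance_scale) for timestep t.

    Assumption: timesteps at or above the boundary are the high-noise steps
    handled by `transformer`; the rest go to `transformer_2`.
    """
    if boundary_timestep is None or t >= boundary_timestep:
        return "transformer", guidance_high
    return "transformer_2", guidance_low


boundary = boundary_from_ratio(0.875)
plan = [select_expert(t, boundary, guidance_low=3.0, guidance_high=4.0)
        for t in (999, 875, 500)]
```

The selected model is what would be passed as current_model to predict_noise.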

prepare_latents

prepare_latents(
    image: Tensor,
    batch_size: int,
    num_channels_latents: int,
    height: int,
    width: int,
    num_frames: int,
    dtype: dtype | None,
    device: device | None,
    generator: Generator | list[Generator] | None,
    latents: Tensor | None = None,
    last_image: Tensor | None = None,
) -> tuple[Tensor, Tensor, Tensor]

Prepare latents for I2V generation.

Returns:

Name Type Description
latents Tensor

Initial noise latents

condition Tensor

Encoded image condition (concatenated with mask for non-expand mode)

first_frame_mask Tensor

Mask for the first frame (1 for frames to denoise, 0 for condition)

Wan22Pipeline

Bases: Module, PipelineParallelMixin, CFGParallelMixin, ProgressBarMixin, DiffusionPipelineProfilerMixin

boundary_ratio instance-attribute

boundary_ratio = boundary_ratio

current_timestep property

current_timestep

device instance-attribute

device = get_local_device()

do_classifier_free_guidance property

do_classifier_free_guidance

expand_timesteps instance-attribute

expand_timesteps = get('expand_timesteps', False)

guidance_scale property

guidance_scale

has_transformer_2 instance-attribute

has_transformer_2 = transformer_2_info[0] is not None

num_timesteps property

num_timesteps

od_config instance-attribute

od_config = od_config

scheduler instance-attribute

scheduler = build_wan_scheduler(_sample_solver, _flow_shift)

text_encoder instance-attribute

text_encoder = to(device)

tokenizer instance-attribute

tokenizer = from_pretrained_with_prefetch(
    from_pretrained,
    model,
    subfolder="tokenizer",
    prefetch_list=component_subfolders,
    local_files_only=local_files_only,
)

transformer instance-attribute

transformer = _create_transformer(transformer_config)

transformer_2 instance-attribute

transformer_2 = _create_transformer(transformer_2_config)

transformer_config instance-attribute

transformer_config = config

vae instance-attribute

vae = to(device)

vae_scale_factor_spatial instance-attribute

vae_scale_factor_spatial = (
    scale_factor_spatial
    if getattr(self, "vae", None)
    else 8
)

vae_scale_factor_temporal instance-attribute

vae_scale_factor_temporal = (
    scale_factor_temporal
    if getattr(self, "vae", None)
    else 4
)

weights_sources instance-attribute

weights_sources = []

check_inputs

check_inputs(
    prompt,
    negative_prompt,
    height,
    width,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    guidance_scale_2=None,
    boundary_ratio=None,
)

diffuse

diffuse(
    latents: Tensor,
    timesteps: Tensor,
    prompt_embeds: Tensor,
    negative_prompt_embeds: Tensor | None,
    guidance_low: float,
    guidance_high: float,
    boundary_timestep: float | None,
    dtype: dtype,
    attention_kwargs: dict[str, Any],
    latent_condition: Tensor | None = None,
    first_frame_mask: Tensor | None = None,
) -> Tensor | AsyncLatents

encode_prompt

encode_prompt(
    prompt: str | list[str],
    negative_prompt: str | list[str] | None = None,
    do_classifier_free_guidance: bool = True,
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 512,
    device: device | None = None,
    dtype: dtype | None = None,
)

forward

forward(
    req: OmniDiffusionRequest,
    prompt: str | None = None,
    negative_prompt: str | None = None,
    height: int = 480,
    width: int = 832,
    num_inference_steps: int = 40,
    guidance_scale: float | tuple[float, float] = 4.0,
    frame_num: int = 81,
    output_type: str | None = "np",
    generator: Generator | list[Generator] | None = None,
    prompt_embeds: Tensor | None = None,
    negative_prompt_embeds: Tensor | None = None,
    attention_kwargs: dict | None = None,
    **kwargs,
) -> DiffusionOutput

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights using AutoWeightsLoader for vLLM integration.

predict_noise

predict_noise(
    current_model: Module | None = None, **kwargs: Any
) -> Tensor | IntermediateTensors

Forward pass through transformer to predict noise.

Parameters:

Name Type Description Default
current_model Module | None

The transformer model to use (transformer or transformer_2)

None
**kwargs Any

Arguments to pass to the transformer

{}

Returns:

Type Description
Tensor | IntermediateTensors

Predicted noise tensor or IntermediateTensors on non-last PP stages.

prepare_latents

prepare_latents(
    batch_size: int,
    num_channels_latents: int,
    height: int,
    width: int,
    num_frames: int,
    dtype: dtype | None,
    device: device | None,
    generator: Generator | list[Generator] | None,
    latents: Tensor | None = None,
) -> Tensor
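
As a rough illustration of how the arguments relate to the noise tensor, the sketch below computes a latent shape from the pixel-space request. It assumes the formula used by the diffusers Wan pipelines (temporal compression keeps the first frame, spatial dimensions are divided by the VAE scale factor) and the fallback scale factors 8 and 4 listed above. The function name latent_shape is hypothetical.

```python
def latent_shape(batch_size, num_channels_latents, height, width, num_frames,
                 vae_scale_factor_spatial=8, vae_scale_factor_temporal=4):
    """Shape of the latent noise tensor as (B, C, T_lat, H_lat, W_lat)."""
    # The first frame is kept; the remaining frames are compressed in time.
    num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
    return (
        batch_size,
        num_channels_latents,
        num_latent_frames,
        height // vae_scale_factor_spatial,
        width // vae_scale_factor_spatial,
    )


# Defaults of forward(): 480x832, 81 frames; 16 latent channels assumed.
shape = latent_shape(1, 16, 480, 832, 81)
```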

Wan22S2VPipeline

Bases: Module, SupportImageInput, SupportAudioInput, CFGParallelMixin, ProgressBarMixin, DiffusionPipelineProfilerMixin

Wan2.2 Speech-to-Video Pipeline.

Migrated from Wan2.2/wan/speech2video.py (WanS2V).

Key differences from I2V
  • Single transformer (WanModel_S2V), no MoE boundary switching.
  • Audio conditioning via wav2vec2 features injected into transformer.
  • Motion frame autoregressive chaining across multiple clips.
  • Optional pose video conditioning.
  • Reference image encoded as separate ref_latents tokens (not channel-concatenated like I2V).

audio_sample_m instance-attribute

audio_sample_m = 0

current_timestep property

current_timestep

device instance-attribute

device = get_local_device()

do_classifier_free_guidance property

do_classifier_free_guidance

drop_first_motion instance-attribute

drop_first_motion = getattr(
    od_config, "drop_first_motion", True
)

dummy_run_num_frames class-attribute

dummy_run_num_frames: int = 0

fps instance-attribute

fps = (
    getattr(od_config, "fps", _DEFAULT_FPS) or _DEFAULT_FPS
)

guidance_scale property

guidance_scale

motion_frames instance-attribute

motion_frames = getattr(
    od_config, "motion_frames", _DEFAULT_MOTION_FRAMES
)

num_timesteps property

num_timesteps

od_config instance-attribute

od_config = od_config

resolution_divisor instance-attribute

resolution_divisor = vae_scale_factor_spatial * 2
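
A small sketch of what a resolution divisor implies for requested sizes. The factor of 2 plausibly corresponds to the spatial patch size of the transformer, but that is an assumption, as is the choice to round down; the helper snap_to_divisor is hypothetical and not part of the pipeline.

```python
def snap_to_divisor(height, width, vae_scale_factor_spatial=8):
    """Round height and width down to the nearest multiple of the divisor."""
    divisor = vae_scale_factor_spatial * 2
    return (height // divisor) * divisor, (width // divisor) * divisor


snapped = snap_to_divisor(710, 1030)
```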

scheduler instance-attribute

scheduler = FlowUniPCMultistepScheduler(
    num_train_timesteps=1000,
    shift=flow_shift,
    prediction_type="flow_prediction",
)

vae_scale_factor_spatial instance-attribute

vae_scale_factor_spatial = (
    scale_factor_spatial if hasattr(vae, "config") else 8
)

vae_scale_factor_temporal instance-attribute

vae_scale_factor_temporal = (
    scale_factor_temporal if hasattr(vae, "config") else 4
)

check_inputs

check_inputs(
    prompt: str | None,
    image: Image | None,
    audio_path: str | None,
    height: int,
    width: int,
    prompt_embeds: Tensor | None = None,
)

diffuse

diffuse(
    latents: Tensor,
    timesteps: Tensor,
    prompt_embeds: Tensor,
    negative_prompt_embeds: Tensor | None,
    guidance_scale: float,
    clip_generator: Generator,
    dtype: dtype,
    device: device,
    max_seq_len: int,
    cond_latents: Tensor,
    input_motion_latents: Tensor,
    ref_latents: Tensor,
    motion_frames: list[int],
    drop_first_motion: bool,
    positive_audio_emb: Tensor,
    negative_audio_emb: Tensor | None,
) -> Tensor

Denoising diffusion loop for one S2V clip.

Parameters:

Name Type Description Default
latents Tensor

Initial noise tensor [C, T, H, W]

required
timesteps Tensor

Denoising timesteps

required
prompt_embeds Tensor

Text embeddings [1, seq_len, dim]

required
negative_prompt_embeds Tensor | None

Negative text embeddings (optional)

required
guidance_scale float

CFG scale

required
clip_generator Generator

Random generator for this clip

required
dtype dtype

Data type for computation

required
device device

Device for computation

required
max_seq_len int

Maximum sequence length for transformer

required
cond_latents Tensor

Pose condition latents [1, 16, T, H, W]

required
input_motion_latents Tensor

Motion latents from previous clip

required
ref_latents Tensor

Reference image latents

required
motion_frames list[int]

Motion frame counts [pixel_frames, latent_frames]

required
drop_first_motion bool

Whether to drop first motion frames (first clip only)

required
positive_audio_emb Tensor

Precomputed audio embeddings for positive prompt

required
negative_audio_emb Tensor | None

Precomputed audio embeddings for negative prompt (optional)

required

Returns:

Type Description
Tensor

Denoised latents [C, T, H, W]

encode_audio

encode_audio(
    audio_path: str | ndarray,
    infer_frames: int,
    device: device | None = None,
    dtype: dtype | None = None,
) -> tuple[Tensor, int, int]

Extract wav2vec2 audio features and bucket them to video frame-rate.

Ported from WanS2V.encode_audio and AudioEncoder.

Parameters:

Name Type Description Default
audio_path str | ndarray

Path to audio file, or raw numpy audio array (16 kHz).

required
infer_frames int

Number of video frames per clip.

required

Returns:

Type Description
tuple[Tensor, int, int]

(audio_embed_bucket, num_repeat, target_video_frames): audio embeddings aligned to video frames with shape [1, num_layers, C_a, T_total], the number of clips needed, and the exact number of video frames to match audio duration.
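
The frame accounting behind num_repeat and target_video_frames can be sketched as follows. This is only an illustration: it ignores the wav2vec2 feature extraction entirely, the rounding rules are assumed, and the values fps=16 and infer_frames=80 are placeholders rather than the documented defaults. The function name plan_clips is hypothetical.

```python
import math


def plan_clips(num_audio_samples, sample_rate=16000, fps=16, infer_frames=80):
    """Return (num_repeat, target_video_frames) for an audio clip.

    Assumptions: video length is the audio duration times fps, rounded up,
    and each autoregressive clip contributes `infer_frames` frames.
    """
    duration_s = num_audio_samples / sample_rate
    target_video_frames = math.ceil(duration_s * fps)
    num_repeat = max(1, math.ceil(target_video_frames / infer_frames))
    return num_repeat, target_video_frames


# 11 seconds of 16 kHz audio.
clips, frames = plan_clips(16000 * 11)
```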

encode_prompt

encode_prompt(
    prompt: str | list[str],
    negative_prompt: str | list[str] | None = None,
    do_classifier_free_guidance: bool = True,
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 512,
    device: device | None = None,
    dtype: dtype | None = None,
) -> tuple[Tensor, Tensor | None]

Encode text prompts using T5 text encoder.

Identical to Wan22I2VPipeline.encode_prompt.

encode_ref_image

encode_ref_image(
    image: Image,
    height: int,
    width: int,
    device: device | None = None,
) -> Tensor

VAE-encode the reference image into latent space.

Ported from WanS2V.generate lines 492-499.

Parameters:

Name Type Description Default
image Image

Reference PIL image (already resized/cropped).

required
height int

Target height.

required
width int

Target width.

required

Returns:

Type Description
Tensor

ref_latents with shape [1, C, 1, H_lat, W_lat].

forward

forward(
    req: OmniDiffusionRequest,
    prompt: str | None = None,
    negative_prompt: str | None = None,
    image: Image | None = None,
    audio_path: str | None = None,
    height: int = 704,
    width: int = 1024,
    infer_frames: int | None = None,
    num_inference_steps: int = 40,
    guidance_scale: float = 4.5,
    num_repeat: int | None = None,
    pose_video: str | None = None,
    init_first_frame: bool = False,
    output_type: str | None = "np",
    generator: Generator | None = None,
    prompt_embeds: Tensor | None = None,
    negative_prompt_embeds: Tensor | None = None,
    **kwargs,
) -> DiffusionOutput

Run S2V generation — may produce multiple autoregressive clips.

This method mirrors WanS2V.generate(), reorganized into the vLLM-Omni pipeline structure.

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights using AutoWeightsLoader for vLLM integration.

predict_noise

predict_noise(
    current_model: Module | None = None, **kwargs: Any
) -> Tensor

Forward pass through the S2V transformer to predict noise.

With return_dict=False, the model returns (output,) where output is [B, C, T, H, W]. We squeeze batch dim since S2V processes one sample at a time.

prepare_latents

prepare_latents(
    infer_frames: int,
    height: int,
    width: int,
    motion_frames: int,
    lat_motion_frames: int,
    dtype: dtype,
    device: device,
    generator: Generator | None = None,
) -> Tensor

Generate random noise latents for one S2V clip.

Ported from WanS2V.generate lines 544-556.

The target latent shape covers only the portion to be denoised; it excludes the motion prefix, which is prepended separately at decode time.

Returns:

Type Description
Tensor

Noise tensor with shape [C, T_lat, H_lat, W_lat] (no batch dim, matching the original WanModel_S2V convention).

prepare_motion_latents

prepare_motion_latents(
    motion_pixels: Tensor, device: device | None = None
) -> Tensor

VAE-encode motion frame pixels into latent space.

Parameters:

Name Type Description Default
motion_pixels Tensor

Shape [1, C, T_motion, H, W], pixel values in [-1, 1].

required

Returns:

Type Description
Tensor

motion_latents with shape [1, C_lat, T_lat, H_lat, W_lat].

Wan22VACEPipeline

Bases: Wan22Pipeline, SupportImageInput

VACE (Video Creation and Editing) Pipeline for Wan2.1.

Extends Wan22Pipeline with VACE-specific context creation and weight loading. All VACE modes (T2V, R2V, V2V, MV2V) are handled by varying the inputs.

check_inputs

check_inputs(
    prompt,
    negative_prompt,
    height,
    width,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    video=None,
    mask=None,
    reference_images=None,
)

diffuse

diffuse(
    latents: Tensor,
    timesteps: Tensor,
    prompt_embeds: Tensor,
    negative_prompt_embeds: Tensor | None,
    guidance_scale: float,
    dtype: dtype,
    attention_kwargs: dict[str, object],
    vace_context: Tensor | None,
    vace_context_scale: float,
) -> Tensor

forward

forward(
    req: OmniDiffusionRequest,
    prompt: str | None = None,
    negative_prompt: str | None = None,
    height: int = 480,
    width: int = 832,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    frame_num: int = 81,
    output_type: str | None = "np",
    generator: Generator | list[Generator] | None = None,
    prompt_embeds: Tensor | None = None,
    negative_prompt_embeds: Tensor | None = None,
    attention_kwargs: dict | None = None,
    vace_context_scale: float | list[float] = 1.0,
    **kwargs,
) -> DiffusionOutput

Generate or edit video using VACE.

The mode is determined by which inputs are provided in the request:
  • T2V: prompt only (no video/mask/reference_images)
  • R2V: prompt + reference_images (in multi_modal_data)
  • V2V: prompt + video (in multi_modal_data)
  • MV2V: prompt + video + mask (in multi_modal_data)
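
The mode rules above can be written as a small decision function. This is a sketch for illustration only: the pipeline does not necessarily compute an explicit mode label, the function name infer_vace_mode is hypothetical, and combinations the description does not cover (for example a mask without a video, or a video together with reference images) are resolved here by assumption.

```python
def infer_vace_mode(video=None, mask=None, reference_images=None):
    """Map the provided inputs to one of the four VACE modes."""
    if video is not None and mask is not None:
        return "MV2V"  # masked video-to-video
    if video is not None:
        return "V2V"   # video-to-video (reference images, if any, are ignored here)
    if reference_images:
        return "R2V"   # reference-to-video
    return "T2V"       # text-to-video (a mask without a video also lands here)


mode = infer_vace_mode(video=object(), mask=object())
```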

Parameters:

Name Type Description Default
req OmniDiffusionRequest

Diffusion request containing prompt and optional multi-modal data.

required
prompt str | None

Text prompt (overridden by req.prompts if provided).

None
negative_prompt str | None

Negative prompt for CFG.

None
height int

Output video height.

480
width int

Output video width.

832
num_inference_steps int

Number of denoising steps.

50
guidance_scale float

CFG scale.

5.0
frame_num int

Number of output frames.

81
output_type str | None

Output format ("np", "pt", or "latent").

'np'
generator Generator | list[Generator] | None

Random generator for reproducibility.

None
prompt_embeds Tensor | None

Pre-computed prompt embeddings.

None
negative_prompt_embeds Tensor | None

Pre-computed negative prompt embeddings.

None
attention_kwargs dict | None

Additional kwargs for attention layers.

None
vace_context_scale float | list[float]

VACE conditioning strength.

1.0

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights using AutoWeightsLoader for vLLM integration.

prepare_masks

prepare_masks(
    mask: Tensor, reference_images: list[list[Tensor]]
) -> Tensor

Encode mask using spatial stride sampling and prepend reference padding.

  • 8x8 spatial stride encoding -> 64 channels
  • Zero-masks prepended for reference image frames
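
To make the "8x8 spatial stride encoding -> 64 channels" step concrete, here is a pure-Python sketch on nested lists. Each of the 64 output channels collects one pixel offset (dy, dx) within every 8x8 block, so no mask values are lost. A tensor implementation would typically use a view/permute/reshape instead; the channel ordering dy * 8 + dx and the function name stride_encode_mask are assumptions, and the reference-frame padding and any temporal resampling are not shown.

```python
def stride_encode_mask(mask, stride=8):
    """Rearrange a [T][H][W] mask into [stride*stride][T][H/stride][W/stride]."""
    num_frames = len(mask)
    height, width = len(mask[0]), len(mask[0][0])
    if height % stride or width % stride:
        raise ValueError("height and width must be multiples of the stride")
    out_h, out_w = height // stride, width // stride

    encoded = []
    for dy in range(stride):          # row offset inside each block
        for dx in range(stride):      # column offset inside each block
            channel = [
                [[mask[t][i * stride + dy][j * stride + dx] for j in range(out_w)]
                 for i in range(out_h)]
                for t in range(num_frames)
            ]
            encoded.append(channel)
    return encoded


# One 16x16 frame whose values encode their own position (y * 16 + x).
demo_mask = [[[y * 16 + x for x in range(16)] for y in range(16)]]
demo_encoded = stride_encode_mask(demo_mask)
```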

prepare_video_latents

prepare_video_latents(
    video: Tensor,
    mask: Tensor,
    reference_images: list[list[Tensor]],
    generator: Generator | None,
    device: device,
) -> Tensor

Encode video and reference images into VACE conditioning latents.

  • Encodes inactive (video * (1-mask)) and reactive (video * mask) regions
  • Reference images are encoded and prepended as extra temporal frames

preprocess_conditions

preprocess_conditions(
    video: list | Tensor | None,
    mask: list | Tensor | None,
    reference_images: list[Image] | None,
    height: int,
    width: int,
    num_frames: int,
    dtype: dtype,
    device: device,
) -> tuple[Tensor, Tensor, list[list[Tensor]]]

Preprocess video, mask, and reference images for VACE conditioning.

  • If video is None, create zero tensor (T2V mode)
  • If mask is None, create all-ones tensor (generate everything)
  • Reference images are resized maintaining aspect ratio and center-padded

Returns:

Type Description
tuple[Tensor, Tensor, list[list[Tensor]]]

(video, mask, reference_images_processed) tensors ready for VAE encoding.
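
The "resized maintaining aspect ratio and center-padded" step for reference images comes down to a little geometry, sketched below. The function name fit_and_center and the rounding choices are hypothetical; the actual resampling and padding value used by the pipeline are not specified here.

```python
def fit_and_center(src_w, src_h, dst_w, dst_h):
    """Return (new_w, new_h, left, top) for an aspect-preserving fit.

    The image is scaled so it fits entirely inside the target canvas, then
    placed so that the leftover space is split evenly on both sides.
    """
    scale = min(dst_w / src_w, dst_h / src_h)
    new_w = min(dst_w, max(1, round(src_w * scale)))
    new_h = min(dst_h, max(1, round(src_h * scale)))
    left = (dst_w - new_w) // 2
    top = (dst_h - new_h) // 2
    return new_w, new_h, left, top


# A square reference image placed on an 832x480 canvas.
placement = fit_and_center(1000, 1000, 832, 480)
```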

WanI2VDMD2Pipeline

Bases: DMD2PipelineMixin, Wan22I2VPipeline

Wan 2.x I2V pipeline for FastGen DMD2-distilled models.

WanT2VDMD2Pipeline

Bases: DMD2PipelineMixin, Wan22Pipeline

Wan 2.x T2V pipeline for FastGen DMD2-distilled models.

WanTransformer3DModel

Bases: Module

Optimized Wan Transformer model for video generation using vLLM layers.

This is an optimized version of the diffusers WanTransformer3DModel that uses vLLM's efficient QKVParallelLinear and RMSNorm implementations.

Sequence Parallelism

This model supports non-intrusive SP via _sp_plan. The plan specifies:
  • RoPE (cos/sin) splitting via rope module's split_output
  • hidden_states splitting at first transformer block input
  • Output gathering at proj_out layer

The video sequence (flattened patches) is parallelized across GPUs.

Note: Our "Sequence Parallelism" (SP) corresponds to "Context Parallelism" (CP) in diffusers.
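
As an illustration of what splitting the flattened patch sequence across GPUs means, the sketch below computes the slice a given rank would own. It assumes the sequence is padded up to a multiple of the SP world size and split into equal contiguous chunks; the real _sp_plan mechanics may differ, and shard_sequence is a hypothetical helper.

```python
def shard_sequence(seq_len, sp_size, rank):
    """Return (start, end, pad) for the token slice owned by `rank`.

    `pad` is the number of padding tokens appended so that the padded
    length divides evenly by `sp_size`.
    """
    if not 0 <= rank < sp_size:
        raise ValueError("rank must be in [0, sp_size)")
    chunk = -(-seq_len // sp_size)   # ceiling division without floats
    padded_len = chunk * sp_size
    start = rank * chunk
    return start, start + chunk, padded_len - seq_len


# 32760 patch tokens split over 4 ranks; rank 1 owns the second chunk.
rank1_slice = shard_sequence(32760, 4, 1)
```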

Parameters:

Name Type Description Default
patch_size tuple[int, int, int]

3D patch dimensions for video embedding (t_patch, h_patch, w_patch)

(1, 2, 2)
num_attention_heads int

Number of attention heads

40
attention_head_dim int

Dimension of each attention head

128
in_channels int

Number of input channels

16
out_channels int

Number of output channels

16
text_dim int

Input dimension for text embeddings

4096
freq_dim int

Dimension for sinusoidal time embeddings

256
ffn_dim int

Intermediate dimension in feed-forward network

13824
num_layers int

Number of transformer blocks

40
cross_attn_norm bool

Enable cross-attention normalization

True
eps float

Epsilon value for normalization layers

1e-06
image_dim int | None

Optional image embedding dimension for I2V

None
added_kv_proj_dim int | None

Optional added KV projection dimension for I2V

None
rope_max_seq_len int

Maximum sequence length for rotary embeddings

1024
pos_embed_seq_len int | None

Optional position embedding sequence length

None

condition_embedder instance-attribute

condition_embedder = WanTimeTextImageEmbedding(
    dim=inner_dim,
    time_freq_dim=freq_dim,
    time_proj_dim=inner_dim * 6,
    text_embed_dim=text_dim,
    image_embed_dim=image_dim,
    pos_embed_seq_len=pos_embed_seq_len,
)

config instance-attribute

config = type(
    "Config",
    (),
    {
        "patch_size": patch_size,
        "num_attention_heads": num_attention_heads,
        "attention_head_dim": attention_head_dim,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "text_dim": text_dim,
        "freq_dim": freq_dim,
        "ffn_dim": ffn_dim,
        "num_layers": num_layers,
        "cross_attn_norm": cross_attn_norm,
        "eps": eps,
        "image_dim": image_dim,
        "added_kv_proj_dim": added_kv_proj_dim,
        "rope_max_seq_len": rope_max_seq_len,
        "pos_embed_seq_len": pos_embed_seq_len,
    },
)()

dtype property

dtype: dtype

Return the dtype of the model parameters.

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states"], inner_dim
    )
)

norm_out instance-attribute

norm_out = AdaLayerNorm(
    inner_dim, elementwise_affine=False, eps=eps
)

output_scale_shift_prepare instance-attribute

output_scale_shift_prepare = OutputScaleShiftPrepare(
    inner_dim
)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "to_qkv": ["to_q", "to_k", "to_v"]
}

patch_embedding instance-attribute

patch_embedding = Conv3dLayer(
    in_channels=in_channels,
    out_channels=inner_dim,
    kernel_size=patch_size,
    stride=patch_size,
)

proj_out instance-attribute

proj_out = Linear(
    inner_dim, out_channels * prod(patch_size)
)
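
The layer sizes above can be checked with a little arithmetic. The sketch assumes inner_dim is num_attention_heads times attention_head_dim (the diffusers convention, not stated on this page) and that the patch embedding, whose kernel and stride both equal patch_size, yields one token per non-overlapping patch. Both helper names are hypothetical.

```python
import math


def transformer_dims(num_attention_heads=40, attention_head_dim=128,
                     out_channels=16, patch_size=(1, 2, 2)):
    """Return (inner_dim, proj_out_features) for the documented defaults."""
    inner_dim = num_attention_heads * attention_head_dim  # assumed definition
    proj_out_features = out_channels * math.prod(patch_size)
    return inner_dim, proj_out_features


def num_patch_tokens(latent_frames, latent_height, latent_width,
                     patch_size=(1, 2, 2)):
    """Number of tokens after patchifying a latent video."""
    pt, ph, pw = patch_size
    return (latent_frames // pt) * (latent_height // ph) * (latent_width // pw)


dims = transformer_dims()
tokens = num_patch_tokens(21, 60, 104)
```

With the defaults, each output token is projected to 64 values, which unpatchify into 16 channels over a 1x2x2 patch.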

rope instance-attribute

rope = WanRotaryPosEmbed(
    attention_head_dim, patch_size, rope_max_seq_len
)

timestep_proj_prepare instance-attribute

timestep_proj_prepare = TimestepProjPrepare()

forward

forward(
    hidden_states: Tensor,
    timestep: LongTensor,
    encoder_hidden_states: Tensor,
    encoder_hidden_states_image: Tensor | None = None,
    intermediate_tensors: IntermediateTensors | None = None,
    return_dict: bool = True,
    attention_kwargs: dict[str, Any] | None = None,
) -> (
    Tensor | Transformer2DModelOutput | IntermediateTensors
)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights from a pretrained model, handling the mapping from separate Q/K/V projections to fused QKV projections for self-attention.

Diffusers weight names:
  • blocks.N.attn1.to_q/to_k/to_v -> fused to blocks.N.attn1.to_qkv (self-attention)
  • blocks.N.attn2.to_q/to_k/to_v -> kept separate (cross-attention)
  • blocks.N.attn1.norm_q/norm_k -> QK normalization for self-attention

Returns:

Type Description
set[str]

Set of parameter names that were successfully loaded.
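
The name mapping described above can be sketched as a pure string transformation. This covers only the renaming step under the stated rules; the actual loader also has to copy each shard into the right slice of the fused parameter, which is not shown. The function map_weight_name and the returned shard id are hypothetical.

```python
import re

# Matches self-attention (attn1) Q/K/V projections only.
_SELF_ATTN_QKV = re.compile(r"^(blocks\.\d+\.attn1)\.to_([qkv])\.(weight|bias)$")


def map_weight_name(name):
    """Return (target_name, shard_id); shard_id is None when nothing is fused."""
    match = _SELF_ATTN_QKV.match(name)
    if match is None:
        # Cross-attention projections, norms, and everything else keep their names.
        return name, None
    prefix, shard_id, suffix = match.groups()
    return f"{prefix}.to_qkv.{suffix}", shard_id


mapped = map_weight_name("blocks.3.attn1.to_k.weight")
```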

WanVACETransformer3DModel

Bases: WanTransformer3DModel

VACE-extended WAN Transformer with conditioning blocks for video editing.

vace_blocks instance-attribute

vace_blocks = None

vace_layers instance-attribute

vace_layers = None

vace_layers_mapping instance-attribute

vace_layers_mapping = None

vace_patch_embedding instance-attribute

vace_patch_embedding = None

embed_vace_context

embed_vace_context(
    vace_context: Tensor, seq_len: int, sp_size: int = 1
) -> Tensor

Compute VACE patch embeddings, aligned and sharded for SP.

Parameters:

Name Type Description Default
vace_context Tensor

Raw conditioning tensor [B, C, T, H, W].

required
seq_len int

Target full (padded) sequence length to align to.

required
sp_size int

Sequence parallel world size.

1

forward

forward(
    hidden_states: Tensor,
    timestep: LongTensor,
    encoder_hidden_states: Tensor,
    encoder_hidden_states_image: Tensor | None = None,
    return_dict: bool = True,
    attention_kwargs: dict[str, Any] | None = None,
    vace_context: Tensor | None = None,
    vace_context_scale: float | list[float] = 1.0,
) -> Tensor | Transformer2DModelOutput

create_transformer_from_config

create_transformer_from_config(
    config: dict,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
) -> WanTransformer3DModel

Create WanTransformer3DModel from config dict.

get_wan22_i2v_post_process_func

get_wan22_i2v_post_process_func(
    od_config: OmniDiffusionConfig,
)

get_wan22_i2v_pre_process_func

get_wan22_i2v_pre_process_func(
    od_config: OmniDiffusionConfig,
)

Pre-process function for I2V: load and resize input image.

get_wan22_post_process_func

get_wan22_post_process_func(od_config: OmniDiffusionConfig)

get_wan22_pre_process_func

get_wan22_pre_process_func(od_config: OmniDiffusionConfig)

Pre-process function for Wan2.2: optionally load and resize input image for I2V mode.

get_wan22_s2v_post_process_func

get_wan22_s2v_post_process_func(
    od_config: OmniDiffusionConfig,
)

get_wan22_s2v_pre_process_func

get_wan22_s2v_pre_process_func(
    od_config: OmniDiffusionConfig,
)

Pre-process function for S2V: load ref image, compute target size.

Expects multi_modal_data to contain:
  • "image": reference image (PIL.Image or file path)
  • "audio": audio file path (str)

Optionally
  • "pose_video": pose conditioning video path (str)
  • "init_first_frame": bool, use ref image as first frame
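
A sketch of what a conforming multi_modal_data dictionary might look like, with a small validator for the required and optional keys listed above. The file names are placeholders, the default of False for init_first_frame mirrors the forward signature, and validate_s2v_inputs is a hypothetical helper rather than part of the pre-process function.

```python
def validate_s2v_inputs(multi_modal_data):
    """Check required S2V keys and fill in defaults for the optional ones."""
    missing = [key for key in ("image", "audio")
               if multi_modal_data.get(key) is None]
    if missing:
        raise ValueError(f"multi_modal_data is missing: {', '.join(missing)}")
    return {
        "image": multi_modal_data["image"],
        "audio": multi_modal_data["audio"],
        "pose_video": multi_modal_data.get("pose_video"),
        "init_first_frame": bool(multi_modal_data.get("init_first_frame", False)),
    }


# Placeholder paths, for illustration only.
example = validate_s2v_inputs({"image": "ref.png", "audio": "speech.wav"})
```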

get_wan22_vace_post_process_func

get_wan22_vace_post_process_func(
    od_config: OmniDiffusionConfig,
)

get_wan22_vace_pre_process_func

get_wan22_vace_pre_process_func(
    od_config: OmniDiffusionConfig,
)

Pre-process function for VACE: handle reference images, source videos, and masks.

load_transformer_config

load_transformer_config(
    model_path: str,
    subfolder: str = "transformer",
    local_files_only: bool = True,
) -> dict

Load transformer config from model directory or HF Hub.

retrieve_latents

retrieve_latents(
    encoder_output: Tensor,
    generator: Generator | None = None,
    sample_mode: str = "sample",
)

Retrieve latents from VAE encoder output.