Skip to content

vllm_omni.diffusion.models.dreamid_omni.fusion

logger module-attribute

logger = init_logger(__name__)

FusedBlock

Bases: Module

Wrapper that pairs a video block with an audio block for layerwise offloading.

Registers both blocks as submodules so their parameters are visible to the offload hooks.

audio_block instance-attribute

audio_block = audio_block

device instance-attribute

device = device

vid_block instance-attribute

vid_block = vid_block

forward

forward(
    hidden_states,
    encoder_hidden_states,
    attn: Attention,
    vid_e,
    vid_seq_lens,
    vid_grid_sizes,
    vid_freqs,
    vid_context,
    vid_context_lens,
    vid_ref_lengths,
    vid_freqs_scaling,
    audio_e,
    audio_seq_lens,
    audio_grid_sizes,
    audio_freqs,
    audio_context,
    audio_context_lens,
    audio_ref_lengths,
    audio_freqs_scaling,
)

FusionModel

Bases: Module

attn instance-attribute

attn = Attention(
    num_heads=num_heads,
    head_size=head_dim,
    num_kv_heads=num_heads,
    softmax_scale=1.0 / head_dim**0.5,
    causal=False,
)

audio_model instance-attribute

audio_model = WanModel(
    quant_config=quant_config,
    prefix="audio_model",
    **audio_config,
)

device instance-attribute

device = get_local_device()

full_num_heads instance-attribute

full_num_heads = num_heads

fused_blocks instance-attribute

fused_blocks = ModuleList(
    [
        (FusedBlock(blocks[i], blocks[i], device))
        for i in (range(num_blocks))
    ]
)

head_dim instance-attribute

head_dim = dim // num_heads

num_blocks instance-attribute

num_blocks = len(blocks)

num_heads instance-attribute

num_heads = full_num_heads // tp_size

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {'to_qkv': ['q', 'k', 'v']}

video_model instance-attribute

video_model = WanModel(
    quant_config=quant_config,
    prefix="video_model",
    **video_config,
)

forward

forward(
    vid,
    audio,
    t,
    vid_context,
    audio_context,
    vid_seq_len,
    audio_seq_len,
    ref_ip_lengths=None,
    ref_audio_lengths=None,
    slg_layer=False,
    freqs_scaling=None,
)

inject_cross_attention_kv_projections

inject_cross_attention_kv_projections(
    quant_config: QuantizationConfig | None = None,
)

load_state_dict

load_state_dict(state_dict, strict=True, assign=False)

Remap checkpoint keys whose blocks are stored under `video_model.blocks.N.*` / `audio_model.blocks.N.*` to the current `fused_blocks.N.vid_block.*` / `fused_blocks.N.audio_block.*` layout.

merge_kwargs

merge_kwargs(vid_kwargs, audio_kwargs)

Keys in each kwargs dict: `e`, `seq_lens`, `grid_sizes`, `freqs`, `context`, `context_lens`.

set_rope_params

set_rope_params()