Skip to content

vllm_omni.diffusion.models.magi_human.magi_human_dit

HAS_MAGI_ATTENTION module-attribute

HAS_MAGI_ATTENTION = find_spec("magi_attention") is not None

is_base_model module-attribute

is_base_model = True

Adapter

Bases: Module

audio_embedder instance-attribute

audio_embedder = Linear(
    audio_in_channels, hidden_size, bias=True, dtype=float32
)

config instance-attribute

config: AdapterConfig = config

rope instance-attribute

rope = ElementWiseFourierEmbed(
    hidden_size // num_attention_heads,
    in_pixels=False,
    learnable=False,
)

text_embedder instance-attribute

text_embedder = Linear(
    text_in_channels, hidden_size, bias=True, dtype=float32
)

video_embedder instance-attribute

video_embedder = Linear(
    video_in_channels, hidden_size, bias=True, dtype=float32
)

forward

forward(
    x, coords_mapping, video_mask, audio_mask, text_mask
)

AdapterConfig dataclass

audio_in_channels instance-attribute

audio_in_channels: int

hidden_size instance-attribute

hidden_size: int

num_attention_heads instance-attribute

num_attention_heads: int

params_dtype instance-attribute

params_dtype: dtype

text_in_channels instance-attribute

text_in_channels: int

video_in_channels instance-attribute

video_in_channels: int

Attention

Bases: Module

config instance-attribute

config: AttentionConfig = config

gating_size instance-attribute

gating_size = num_heads_q if enable_attn_gating else 0

k_norm instance-attribute

k_norm = MultiModalityRMSNorm(
    head_dim, num_modality=num_modality
)

kv_size instance-attribute

kv_size = num_kv_heads * head_dim

linear_gating instance-attribute

linear_gating = ColumnParallelLinear(
    input_size=hidden_size,
    output_size=num_heads_q,
    bias=False,
    gather_output=False,
    return_bias=False,
)

linear_proj instance-attribute

linear_proj = RowParallelLinear(
    input_size=num_heads_q * head_dim,
    output_size=hidden_size,
    bias=False,
    input_is_parallel=True,
    return_bias=False,
)

linear_qkv instance-attribute

linear_qkv = QKVParallelLinear(
    hidden_size=hidden_size,
    head_size=head_dim,
    total_num_heads=num_heads_q,
    total_num_kv_heads=num_heads_kv,
    bias=False,
    return_bias=False,
)

pre_norm instance-attribute

pre_norm = MultiModalityRMSNorm(
    hidden_size, eps=1e-06, num_modality=num_modality
)

q_norm instance-attribute

q_norm = MultiModalityRMSNorm(
    head_dim, num_modality=num_modality
)

q_size instance-attribute

q_size = num_heads * head_dim

forward

forward(
    hidden_states: Tensor,
    rope: Tensor,
    permute_mapping: Tensor,
    inv_permute_mapping: Tensor,
    varlen_handler: VarlenHandler,
    local_attn_handler: FFAHandler | None,
    modality_dispatcher: ModalityDispatcher,
) -> Tensor

AttentionConfig dataclass

checkpoint_qk_layernorm_rope instance-attribute

checkpoint_qk_layernorm_rope: bool

enable_attn_gating class-attribute instance-attribute

enable_attn_gating: bool = False

head_dim instance-attribute

head_dim: int

hidden_size instance-attribute

hidden_size: int

num_heads_kv instance-attribute

num_heads_kv: int

num_heads_q instance-attribute

num_heads_q: int

num_layers instance-attribute

num_layers: int

num_modality instance-attribute

num_modality: int

params_dtype instance-attribute

params_dtype: dtype

use_local_attn class-attribute instance-attribute

use_local_attn: bool = False

BaseLinear

Bases: Module

bias instance-attribute

bias = Parameter(
    empty(out_features * num_experts, **factory_kwargs)
)

in_features instance-attribute

in_features = in_features

num_experts instance-attribute

num_experts = num_experts

num_layers_for_initialization instance-attribute

num_layers_for_initialization = (
    num_layers_for_initialization
)

out_features instance-attribute

out_features = out_features

use_bias instance-attribute

use_bias = bias

weight instance-attribute

weight = Parameter(
    empty(
        (out_features * num_experts, in_features),
        **factory_kwargs,
    )
)

forward

forward(
    input: Tensor,
    output_dtype: dtype | None = None,
    modality_dispatcher: ModalityDispatcher | None = None,
) -> Tensor

DiTModel

Bases: Module

adapter instance-attribute

adapter: Adapter = Adapter(adapter_config)

block instance-attribute

block: TransformerBlock = TransformerBlock(
    model_config=model_config
)

blocks property

blocks: ModuleList

config instance-attribute

config: TransformerConfig = TransformerConfig(
    hidden_size=hidden_size,
    video_in_channels=video_in_channels,
    audio_in_channels=audio_in_channels,
    text_in_channels=text_in_channels,
    params_dtype=params_dtype,
    post_process_dtype=float32,
)

final_linear_audio instance-attribute

final_linear_audio = Linear(
    hidden_size,
    audio_in_channels,
    bias=False,
    dtype=float32,
)

final_linear_video instance-attribute

final_linear_video = Linear(
    hidden_size,
    video_in_channels,
    bias=False,
    dtype=float32,
)

final_norm_audio instance-attribute

final_norm_audio = MultiModalityRMSNorm(hidden_size)

final_norm_video instance-attribute

final_norm_video = MultiModalityRMSNorm(hidden_size)

forward

forward(
    x: Tensor,
    coords_mapping: Tensor,
    modality_mapping: Tensor,
    varlen_handler: VarlenHandler,
    local_attn_handler: FFAHandler | None,
)

ElementWiseFourierEmbed

Bases: Module

bands instance-attribute

bands = Parameter(bands, requires_grad=learnable)

device instance-attribute

device = device

dim instance-attribute

dim = dim

dtype instance-attribute

dtype = dtype

in_pixels instance-attribute

in_pixels = in_pixels

learnable instance-attribute

learnable = learnable

linear_bands instance-attribute

linear_bands = linear_bands

max_res instance-attribute

max_res = max_res

temperature instance-attribute

temperature = temperature

forward

forward(coords: Tensor) -> Tensor

get_default_bands

get_default_bands()

reset_parameters

reset_parameters()

FFAHandler dataclass

attn_type_map instance-attribute

attn_type_map: Tensor

k_ranges instance-attribute

k_ranges: Tensor

max_seqlen_k instance-attribute

max_seqlen_k: int

max_seqlen_q instance-attribute

max_seqlen_q: int

q_ranges instance-attribute

q_ranges: Tensor

softmax_scale instance-attribute

softmax_scale: float

MLP

Bases: Module

activation_func instance-attribute

activation_func = create_activation_func(activation_type)

config instance-attribute

config: MLPConfig

down_proj instance-attribute

down_proj = RowParallelLinear(
    input_size=intermediate_size,
    output_size=hidden_size,
    bias=False,
    input_is_parallel=True,
    return_bias=False,
)

pre_norm instance-attribute

pre_norm = MultiModalityRMSNorm(
    hidden_size, num_modality=num_modality
)

up_gate_proj instance-attribute

up_gate_proj = ColumnParallelLinear(
    input_size=hidden_size,
    output_size=intermediate_size_up,
    bias=False,
    gather_output=False,
    return_bias=False,
)

forward

forward(
    x: Tensor, modality_dispatcher: ModalityDispatcher
) -> Tensor

MLPActivationType

Bases: Enum

GELU7 class-attribute instance-attribute

GELU7 = 'gelu7'

SWIGLU7 class-attribute instance-attribute

SWIGLU7 = 'swiglu7'

MLPConfig dataclass

activation_type instance-attribute

activation_type: MLPActivationType

gated_act class-attribute instance-attribute

gated_act: bool = False

hidden_size instance-attribute

hidden_size: int

intermediate_size instance-attribute

intermediate_size: int

num_layers class-attribute instance-attribute

num_layers: int = 1

num_modality class-attribute instance-attribute

num_modality: int = 1

params_dtype instance-attribute

params_dtype: dtype

MagiDataProxy

coords_style instance-attribute

coords_style = coords_style

frame_receptive_field instance-attribute

frame_receptive_field = frame_receptive_field

patch_size instance-attribute

patch_size = patch_size

ref_audio_offset instance-attribute

ref_audio_offset = ref_audio_offset

spatial_rope_interpolation instance-attribute

spatial_rope_interpolation = spatial_rope_interpolation

t_patch_size instance-attribute

t_patch_size = t_patch_size

text_offset instance-attribute

text_offset = text_offset

get_saved_data

get_saved_data(key: str)

img2tokens

img2tokens(x_t: Tensor)

process_input

process_input(transported_data: EvalInput)

process_output

process_output(x: Tensor)

saved_for_output

saved_for_output(**kwargs)

MagiHumanDiTConfig dataclass

audio_in_channels class-attribute instance-attribute

audio_in_channels: int = 64

checkpoint_qk_layernorm_rope class-attribute instance-attribute

checkpoint_qk_layernorm_rope: bool = False

enable_attn_gating class-attribute instance-attribute

enable_attn_gating: bool = True

gelu7_layers class-attribute instance-attribute

gelu7_layers: list = field(
    default_factory=lambda: [0, 1, 2, 3]
)

head_dim class-attribute instance-attribute

head_dim: int = 128

hidden_size class-attribute instance-attribute

hidden_size: int = 5120

local_attn_layers class-attribute instance-attribute

local_attn_layers: list = field(default_factory=list)

mm_layers class-attribute instance-attribute

mm_layers: list = field(
    default_factory=lambda: [0, 1, 2, 3, 36, 37, 38, 39]
)

num_layers class-attribute instance-attribute

num_layers: int = 40

num_query_groups class-attribute instance-attribute

num_query_groups: int = 8

params_dtype class-attribute instance-attribute

params_dtype: dtype = float32

post_norm_layers class-attribute instance-attribute

post_norm_layers: list = field(default_factory=list)

text_in_channels class-attribute instance-attribute

text_in_channels: int = 3584

video_in_channels class-attribute instance-attribute

video_in_channels: int = 48 * 4

MoEColumnParallelLinear

Bases: Module

Per-expert ColumnParallelLinear with modality dispatch.

Forward: dispatch tokens by modality → per-expert column-parallel matmul (TP-sharded) → undispatch. The output stays TP-local (it is not gathered across tensor-parallel ranks).

experts instance-attribute

experts = ModuleList(
    [
        (
            ColumnParallelLinear(
                input_size=input_size,
                output_size=output_size,
                bias=bias,
                gather_output=False,
                return_bias=False,
            )
        )
        for _ in (range(num_experts))
    ]
)

num_experts instance-attribute

num_experts = num_experts

forward

forward(
    x: Tensor, modality_dispatcher: ModalityDispatcher
) -> Tensor

MoEQKVParallelLinear

Bases: Module

Per-expert QKVParallelLinear with modality dispatch.

Wraps num_experts independent QKVParallelLinear instances. Forward: dispatch tokens by modality → per-expert QKV matmul (TP-sharded) → undispatch.

experts instance-attribute

experts = ModuleList(
    [
        (
            QKVParallelLinear(
                hidden_size=hidden_size,
                head_size=head_size,
                total_num_heads=total_num_heads,
                total_num_kv_heads=total_num_kv_heads,
                bias=bias,
                return_bias=False,
            )
        )
        for _ in (range(num_experts))
    ]
)

head_size instance-attribute

head_size = head_size

num_experts instance-attribute

num_experts = num_experts

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

forward

forward(
    x: Tensor, modality_dispatcher: ModalityDispatcher
) -> Tensor

MoERowParallelLinear

Bases: Module

Per-expert RowParallelLinear with modality dispatch.

Forward: dispatch tokens by modality → per-expert row-parallel matmul (which includes the tensor-parallel all-reduce) → undispatch.

experts instance-attribute

experts = ModuleList(
    [
        (
            RowParallelLinear(
                input_size=input_size,
                output_size=output_size,
                bias=bias,
                input_is_parallel=True,
                return_bias=False,
            )
        )
        for _ in (range(num_experts))
    ]
)

num_experts instance-attribute

num_experts = num_experts

forward

forward(
    x: Tensor, modality_dispatcher: ModalityDispatcher
) -> Tensor

Modality

Bases: IntEnum

AUDIO class-attribute instance-attribute

AUDIO = 1

TEXT class-attribute instance-attribute

TEXT = 2

VIDEO class-attribute instance-attribute

VIDEO = 0

ModalityDispatcher

group_size instance-attribute

group_size: Tensor = to(int32)

group_size_cpu instance-attribute

group_size_cpu: list[int] = [(int(x)) for x in (tolist())]

modality_mapping instance-attribute

modality_mapping = modality_mapping

num_modalities instance-attribute

num_modalities: int = num_modalities

permuted_modality_mapping instance-attribute

permuted_modality_mapping: Tensor = (
    _precompute_permute_mapping(modality_mapping)
)

dispatch

dispatch(x: Tensor) -> list[Tensor]

inv_permute staticmethod

inv_permute(
    x: Tensor, inv_permute_mapping: Tensor
) -> Tensor

permute staticmethod

permute(x: Tensor, permute_mapping: Tensor) -> Tensor

undispatch

undispatch(*processed_groups: list[Tensor]) -> Tensor

MultiModalityRMSNorm

Bases: Module

dim instance-attribute

dim = dim

eps instance-attribute

eps = eps

forward instance-attribute

forward = forward_multi_experts

num_modality instance-attribute

num_modality = num_modality

weight instance-attribute

weight = Parameter(
    zeros(dim * num_modality, device=device, dtype=float32)
)

forward_multi_experts

forward_multi_experts(
    x: Tensor, modality_dispatcher: ModalityDispatcher
) -> Tensor

forward_single_expert

forward_single_expert(
    x: Tensor,
    modality_dispatcher: ModalityDispatcher | None = None,
) -> Tensor

reset_parameters

reset_parameters()

rms

rms(x: Tensor) -> Tensor

NativeMoELinear

Bases: BaseLinear

forward

forward(
    input: Tensor,
    output_dtype: dtype | None = None,
    modality_dispatcher: ModalityDispatcher | None = None,
) -> Tensor

SimplePackedData dataclass

coords_mapping property

coords_mapping

cu_seqlen property

cu_seqlen

items instance-attribute

items: list[SingleData]

max_seqlen property

max_seqlen

modality_mapping property

modality_mapping

token_sequence property

token_sequence

total_token_num property

total_token_num

depack_token_sequence

depack_token_sequence(token_sequence)

SingleData dataclass

audio_feat_len instance-attribute

audio_feat_len: int

audio_x_t instance-attribute

audio_x_t: Tensor

coords_mapping property

coords_mapping

coords_style class-attribute instance-attribute

coords_style: Literal['v1', 'v2'] = 'v1'

default_dtype property

default_dtype

device property

device

h instance-attribute

h: int

modality_mapping property

modality_mapping

patch_size instance-attribute

patch_size: int

ref_audio_offset instance-attribute

ref_audio_offset: int

spatial_rope_interpolation instance-attribute

spatial_rope_interpolation: Literal['inter', 'extra']

t instance-attribute

t: int

t_patch_size instance-attribute

t_patch_size: int

text_offset instance-attribute

text_offset: int

token_sequence property

token_sequence

total_token_num property

total_token_num

txt_feat instance-attribute

txt_feat: Tensor

txt_feat_len instance-attribute

txt_feat_len: int

video_x_t instance-attribute

video_x_t: Tensor

w instance-attribute

w: int

default_coords

default_coords(shape, ref_feat_shape, offset_thw=None)

depack_token_sequence

depack_token_sequence(token_sequence)

TransFormerLayer

Bases: Module

attention instance-attribute

attention: Attention = Attention(attention_config)

attn_post_norm instance-attribute

attn_post_norm = MultiModalityRMSNorm(
    hidden_size, num_modality=num_modality
)

mlp instance-attribute

mlp: MLP = MLP(mlp_config)

mlp_post_norm instance-attribute

mlp_post_norm = MultiModalityRMSNorm(
    hidden_size, num_modality=num_modality
)

post_norm instance-attribute

post_norm = layer_idx in post_norm_layers

forward

forward(
    hidden_states: Tensor,
    rope: Tensor,
    permute_mapping: Tensor,
    inv_permute_mapping: Tensor,
    varlen_handler: VarlenHandler,
    local_attn_handler: FFAHandler | None,
    modality_dispatcher: ModalityDispatcher,
) -> Tensor

TransformerBlock

Bases: Module

layers instance-attribute

layers: list[TransFormerLayer] = ModuleList()

forward

forward(
    x: Tensor,
    rope: Tensor,
    permute_mapping: Tensor,
    inv_permute_mapping: Tensor,
    varlen_handler: VarlenHandler,
    local_attn_handler: FFAHandler | None,
    modality_dispatcher: ModalityDispatcher,
) -> Tensor

TransformerConfig dataclass

audio_in_channels instance-attribute

audio_in_channels: int

hidden_size instance-attribute

hidden_size: int

params_dtype instance-attribute

params_dtype: dtype

post_process_dtype instance-attribute

post_process_dtype: dtype

text_in_channels instance-attribute

text_in_channels: int

video_in_channels instance-attribute

video_in_channels: int

VarlenHandler dataclass

cu_seqlens_k instance-attribute

cu_seqlens_k: Tensor

cu_seqlens_q instance-attribute

cu_seqlens_q: Tensor

max_seqlen_k instance-attribute

max_seqlen_k: int

max_seqlen_q instance-attribute

max_seqlen_q: int

apply_rotary_emb_torch

apply_rotary_emb_torch(x, cos, sin, interleaved=False)

calc_local_attn_ffa_handler

calc_local_attn_ffa_handler(
    num_video_tokens,
    num_audio_and_txt_tokens,
    num_frames,
    frame_receptive_field,
)

calc_local_qk_range

calc_local_qk_range(
    num_video_tokens,
    num_audio_and_txt_tokens,
    num_frames,
    frame_receptive_field,
)

config_patch

config_patch(
    compile_config: CompileConfig,
) -> CompileConfig

create_activation_func

create_activation_func(
    activation_type: MLPActivationType,
) -> Callable

create_linear

create_linear(
    in_features,
    out_features,
    num_layers=1,
    num_experts=1,
    bias=True,
    device=None,
    dtype=None,
) -> BaseLinear | NativeMoELinear

flash_attn_func

flash_attn_func(
    query: Tensor, key: Tensor, value: Tensor
) -> Tensor

flash_attn_no_cp

flash_attn_no_cp(q: Tensor, k: Tensor, v: Tensor) -> Tensor

flex_flash_attn_func

flex_flash_attn_func(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    q_ranges: Tensor,
    k_ranges: Tensor,
) -> tuple[Tensor, Tensor]

flex_flash_attn_no_cp

flex_flash_attn_no_cp(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_ranges: Tensor,
    k_ranges: Tensor,
) -> Tensor

freq_bands

freq_bands(
    num_bands: int,
    temperature: float = 10000.0,
    step: int = 2,
    device: device | None = None,
) -> Tensor

gelu7

gelu7(
    x,
    alpha: float = 1.702,
    limit: float = 7.0,
    out_dtype: dtype | None = None,
)

get_coords

get_coords(
    shape: list[int],
    ref_feat_shape: list[int],
    offset_thw: list[int] | None = None,
    device: device = device("cpu"),
    dtype: dtype = float32,
)

magi_compile

magi_compile(*args, **kwargs)

No-op stub: vllm-omni handles execution itself, so magi compilation is skipped.

rotate_half

rotate_half(x, interleaved=False)

swiglu7

swiglu7(
    x,
    alpha: float = 1.702,
    limit: float = 7.0,
    out_dtype: dtype | None = None,
)

validate_magi_human_tp_constraints

validate_magi_human_tp_constraints(
    *,
    hidden_size: int,
    num_heads_q: int,
    num_heads_kv: int,
    tensor_parallel_size: int,
) -> None

Validate MagiHuman tensor-parallel (TP) divisibility constraints.

Both shared layers (num_modality == 1) and MoE layers (num_modality == 3) support tensor parallelism via vLLM's parallel linear layers (QKVParallelLinear / ColumnParallelLinear / RowParallelLinear). MoE layers use one set of parallel layers per expert, with tokens dispatched by modality.

Supported tensor-parallel sizes (tp_size) for the default config (hidden=5120, heads_q=40, kv=8): 1, 2, 4.