Skip to content

vllm_omni.diffusion.models.audiox.audiox_transformer

logger module-attribute

logger = getLogger(__name__)

AudioXCrossAttention

Bases: Module

attn instance-attribute

attn = Attention(
    num_heads=local_nheads,
    head_size=head_dim,
    softmax_scale=head_dim**-0.5,
    causal=False,
)

dim instance-attribute

dim = dim

head_dim instance-attribute

head_dim = head_dim

k_norm instance-attribute

k_norm = AudioXRMSNorm(head_dim)

nheads instance-attribute

nheads = nheads

q_norm instance-attribute

q_norm = AudioXRMSNorm(head_dim)

to_kv instance-attribute

to_kv = MergedColumnParallelLinear(
    input_size=dim,
    output_sizes=[dim, dim],
    bias=False,
    gather_output=False,
    prefix=f"{prefix}.to_kv",
)

to_q instance-attribute

to_q = ColumnParallelLinear(
    dim,
    dim,
    bias=False,
    gather_output=False,
    prefix=f"{prefix}.to_q",
)

forward

forward(x: Tensor, context: Tensor | None = None) -> Tensor

AudioXMMChannelLastConv1d

Bases: Conv1d

forward

forward(x: Tensor) -> Tensor

AudioXMMConvFeedForward

Bases: Module

w1 instance-attribute

w1 = _ColumnParallelChannelLastConv1d(
    dim,
    hidden_dim,
    bias=False,
    kernel_size=kernel_size,
    padding=padding,
)

w2 instance-attribute

w2 = _RowParallelChannelLastConv1d(
    hidden_dim,
    dim,
    bias=False,
    kernel_size=kernel_size,
    padding=padding,
)

w3 instance-attribute

w3 = _ColumnParallelChannelLastConv1d(
    dim,
    hidden_dim,
    bias=False,
    kernel_size=kernel_size,
    padding=padding,
)

forward

forward(x)

AudioXMMDiTBlock

Bases: Module

adaLN_modulation instance-attribute

adaLN_modulation = Sequential(
    SiLU(), Linear(dim, 6 * dim, bias=True)
)

attn instance-attribute

attn = AudioXMMDiTSelfAttention(
    dim, nhead, prefix=f"{prefix}.attn"
)

cross_attn instance-attribute

cross_attn = AudioXCrossAttention(
    dim, nhead, prefix=f"{prefix}.cross_attn"
)

ffn instance-attribute

ffn = AudioXMMConvFeedForward(
    dim, int(dim * mlp_ratio), kernel_size=3, padding=1
)

linear1 instance-attribute

linear1 = AudioXMMChannelLastConv1d(
    dim, dim, kernel_size=3, padding=1
)

norm1 instance-attribute

norm1 = LayerNorm(dim, elementwise_affine=False)

norm2 instance-attribute

norm2 = LayerNorm(dim, elementwise_affine=False)

forward

forward(
    x: Tensor,
    cond: Tensor,
    rot: tuple[Tensor, Tensor] | None,
    context: Tensor = None,
) -> Tensor

post_attention

post_attention(
    x: Tensor,
    attn_out: Tensor,
    c: tuple[Tensor, ...],
    context=None,
)

pre_attention

pre_attention(
    x: Tensor, c: Tensor, rot: tuple[Tensor, Tensor] | None
)

AudioXMMDiTSelfAttention

Bases: Module

attn instance-attribute

attn = Attention(
    num_heads=num_heads,
    head_size=head_dim,
    softmax_scale=head_dim**-0.5,
    causal=False,
)

dim instance-attribute

dim = dim

head_dim instance-attribute

head_dim = head_dim

k_norm instance-attribute

k_norm = AudioXRMSNorm(head_dim)

nheads instance-attribute

nheads = nheads

q_norm instance-attribute

q_norm = AudioXRMSNorm(head_dim)

qkv instance-attribute

qkv = QKVParallelLinear(
    hidden_size=dim,
    head_size=head_dim,
    total_num_heads=nheads,
    bias=True,
    prefix=f"{prefix}.qkv",
)

rope instance-attribute

rope = RotaryEmbedding(is_neox_style=False)

apply_attention

apply_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor

forward

forward(
    x: Tensor, rot: tuple[Tensor, Tensor] | None = None
) -> Tensor

pre_attention

pre_attention(
    x: Tensor, rot: tuple[Tensor, Tensor] | None = None
)

AudioXRMSNorm

Bases: Module

eps instance-attribute

eps = eps

forward

forward(x: Tensor) -> Tensor

ContinuousMMDiTTransformer

Bases: Module

depth instance-attribute

depth = depth

device property

device

dim instance-attribute

dim = dim

layers instance-attribute

layers = ModuleList(
    [
        (
            AudioXMMDiTBlock(
                hidden_dim,
                num_heads,
                mlp_ratio=mlp_ratio,
                prefix=f"layers.{i}",
            )
        )
        for i in (range(depth))
    ]
)

proj_mm_seq_len instance-attribute

proj_mm_seq_len = (
    Linear(384, _latent_seq_len)
    if _latent_seq_len != 384
    else Identity()
)

proj_mm_tokens instance-attribute

proj_mm_tokens = (
    Linear(768, hidden_dim) if dim != 768 else Identity()
)

project_in instance-attribute

project_in = (
    Linear(dim_in, dim, bias=False)
    if dim_in is not None
    else Identity()
)

project_out instance-attribute

project_out = (
    Linear(dim, dim_out, bias=False)
    if dim_out is not None
    else Identity()
)

forward

forward(x, prepend_embeds=None, context=None)

MMDiffusionTransformer

Bases: Module

AudioX MMDiT, specialized for the published bundle (zhangj1an/AudioX).

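The published bundle pins patch_size=1, transformer_type="continuous_transformer", and cond_token_dim=768 (i.e. > 0, with project_cond_tokens=False), and it never sets prepend_cond_dim or input_concat_dim. The code paths that would handle any other values of these options have therefore been removed from this implementation.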

cond_token_dim instance-attribute

cond_token_dim = cond_token_dim

postprocess_conv instance-attribute

postprocess_conv = Conv1d(
    io_channels, io_channels, 1, bias=False
)

preprocess_conv instance-attribute

preprocess_conv = Conv1d(
    io_channels, io_channels, 1, bias=False
)

timestep_features instance-attribute

timestep_features = GaussianFourierProjection(
    in_features=1,
    embedding_size=timestep_features_dim // 2,
    scale=1.0,
    trainable=False,
)

to_cond_embed instance-attribute

to_cond_embed = Sequential(
    Linear(cond_token_dim, cond_embed_dim, bias=False),
    SiLU(),
    Linear(cond_embed_dim, cond_embed_dim, bias=False),
)

to_global_embed instance-attribute

to_global_embed = Sequential(
    Linear(global_cond_dim, global_embed_dim, bias=False),
    SiLU(),
    Linear(global_embed_dim, global_embed_dim, bias=False),
)

to_timestep_embed instance-attribute

to_timestep_embed = Sequential(
    Linear(timestep_features_dim, embed_dim, bias=True),
    SiLU(),
    Linear(embed_dim, embed_dim, bias=True),
)

transformer instance-attribute

transformer = ContinuousMMDiTTransformer(
    dim=embed_dim,
    depth=depth,
    dim_heads=embed_dim // num_heads,
    dim_in=io_channels,
    dim_out=io_channels,
)

forward

forward(
    x,
    t,
    cross_attn_cond,
    negative_cross_attn_cond=None,
    negative_cross_attn_mask=None,
    cfg_scale: float = 1.0,
    scale_phi: float = 0.0,
    **kwargs,
)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]