Skip to content

vllm_omni.diffusion.models.cosyvoice3_audio.cosyvoice3_dit

logger module-attribute

logger = init_logger(__name__)

Notation: b - batch; n - sequence; nt - text sequence; nw - raw wave length; d - dimension.

AdaLayerNormZero_Final

Bases: Module

AdaLayerNormZero variant for the final layer; returns only the modulated x.

linear instance-attribute

linear = nn.Linear(dim, dim * 2)

norm instance-attribute

norm = nn.LayerNorm(
    dim, elementwise_affine=False, eps=1e-06
)

silu instance-attribute

silu = nn.SiLU()

forward

forward(x, emb)

CausalConvPositionEmbedding

Bases: Module

Causal convolutional position embedding.

conv1 instance-attribute

conv1 = nn.Sequential(
    nn.Conv1d(
        dim, dim, kernel_size, groups=groups, padding=0
    ),
    nn.Mish(),
)

conv2 instance-attribute

conv2 = nn.Sequential(
    nn.Conv1d(
        dim, dim, kernel_size, groups=groups, padding=0
    ),
    nn.Mish(),
)

kernel_size instance-attribute

kernel_size = kernel_size

forward

forward(x: Tensor, mask: Tensor | None = None)

ConvNeXtV2Block

Bases: Module

ConvNeXt-V2 Block.

act instance-attribute

act = nn.GELU()

dwconv instance-attribute

dwconv = nn.Conv1d(
    dim,
    dim,
    kernel_size=7,
    padding=padding,
    groups=dim,
    dilation=dilation,
)

grn instance-attribute

grn = GRN(intermediate_dim)

norm instance-attribute

norm = nn.LayerNorm(dim, eps=1e-06)

pwconv1 instance-attribute

pwconv1 = nn.Linear(dim, intermediate_dim)

pwconv2 instance-attribute

pwconv2 = nn.Linear(intermediate_dim, dim)

forward

forward(x: Tensor) -> Tensor

DiT

Bases: Module

Diffusion Transformer backbone using optimized attention backends.

This is a drop-in replacement for the original DiT. It uses the vllm_omni diffusion infrastructure for FlashAttention/SageAttention/SDPA.

depth instance-attribute

depth = depth

dim instance-attribute

dim = dim

input_embed instance-attribute

input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim)

long_skip_connection instance-attribute

long_skip_connection = (
    nn.Linear(dim * 2, dim, bias=False)
    if long_skip_connection
    else None
)

norm_out instance-attribute

norm_out = AdaLayerNormZero_Final(dim)

num_decoding_left_chunks instance-attribute

num_decoding_left_chunks = num_decoding_left_chunks

out_channels instance-attribute

out_channels = out_channels

proj_out instance-attribute

proj_out = nn.Linear(dim, mel_dim)

rotary_embed instance-attribute

rotary_embed = RotaryEmbedding(dim_head)

static_chunk_size instance-attribute

static_chunk_size = static_chunk_size

time_embed instance-attribute

time_embed = DiTTimestepEmbedding(dim)

transformer_blocks instance-attribute

transformer_blocks = nn.ModuleList(
    [
        (
            DiTBlock(
                dim=dim,
                heads=heads,
                dim_head=dim_head,
                ff_mult=ff_mult,
                dropout=dropout,
            )
        )
        for _ in (range(depth))
    ]
)

forward

forward(x, mask, mu, t, spks=None, cond=None)

DiTAttention

Bases: Module

Attention module using diffusion infrastructure for optimized backends.

This replaces the original Attention class so that FlashAttention, SageAttention, or SDPA backends are leveraged automatically.

attn instance-attribute

attn = DiffusionAttention(
    num_heads=heads,
    head_size=dim_head,
    softmax_scale=self.scale,
    causal=False,
)

dim instance-attribute

dim = dim

dim_head instance-attribute

dim_head = dim_head

dropout instance-attribute

dropout = dropout

heads instance-attribute

heads = heads

inner_dim instance-attribute

inner_dim = dim_head * heads

scale instance-attribute

scale = 1.0 / math.sqrt(dim_head)

to_k instance-attribute

to_k = nn.Linear(dim, self.inner_dim)

to_out instance-attribute

to_out = nn.Sequential(
    nn.Linear(self.inner_dim, dim), nn.Dropout(dropout)
)

to_q instance-attribute

to_q = nn.Linear(dim, self.inner_dim)

to_v instance-attribute

to_v = nn.Linear(dim, self.inner_dim)

forward

forward(
    x: Tensor, mask: Tensor | None = None, rope=None
) -> Tensor

DiTBlock

Bases: Module

DiT block with AdaLayerNorm modulation.

attn instance-attribute

attn = DiTAttention(
    dim=dim, heads=heads, dim_head=dim_head, dropout=dropout
)

attn_norm instance-attribute

attn_norm = AdaLayerNormZero(dim)

ff instance-attribute

ff = FeedForward(
    dim=dim,
    mult=ff_mult,
    dropout=dropout,
    approximate="tanh",
)

ff_norm instance-attribute

ff_norm = nn.LayerNorm(
    dim, elementwise_affine=False, eps=1e-06
)

forward

forward(x, t, mask=None, rope=None)

FeedForward

Bases: Module

Feed-forward network with GELU activation.

ff instance-attribute

ff = nn.Sequential(
    project_in,
    nn.Dropout(dropout),
    nn.Linear(inner_dim, dim_out),
)

forward

forward(x)

GRN

Bases: Module

Global Response Normalization layer.

beta instance-attribute

beta = nn.Parameter(torch.zeros(1, 1, dim))

gamma instance-attribute

gamma = nn.Parameter(torch.zeros(1, 1, dim))

forward

forward(x)

InputEmbedding

Bases: Module

Input embedding combining noised audio, condition, text, and speaker.

conv_pos_embed instance-attribute

conv_pos_embed = CausalConvPositionEmbedding(dim=out_dim)

proj instance-attribute

proj = nn.Linear(mel_dim * 2 + text_dim + spk_dim, out_dim)

spk_dim instance-attribute

spk_dim = spk_dim

forward

forward(x, cond, text_embed, spks)

TextEmbedding

Bases: Module

Text embedding with optional ConvNeXt modeling.

extra_modeling instance-attribute

extra_modeling = True

precompute_max_pos instance-attribute

precompute_max_pos = 4096

text_blocks instance-attribute

text_blocks = nn.Sequential(
    *[
        (ConvNeXtV2Block(text_dim, text_dim * conv_mult))
        for _ in (range(conv_layers))
    ]
)

text_embed instance-attribute

text_embed = nn.Embedding(text_num_embeds + 1, text_dim)

forward

forward(text: Tensor, seq_len, drop_text=False)

get_pos_embed_indices

get_pos_embed_indices(start, length, max_pos, scale=1.0)

Get position embedding indices.

precompute_freqs_cis

precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    theta_rescale_factor=1.0,
)

Precompute rotary embedding frequencies.