
vllm_omni.diffusion.models.ltx2.ltx2_transformer

logger module-attribute

logger = init_logger(__name__)

AudioVisualModelOutput dataclass

Bases: BaseOutput

Holds the output of an audiovisual model which produces both visual (e.g. video) and audio outputs.

Parameters:

Name Type Description Default
sample `torch.Tensor` of shape `(batch_size, num_video_tokens, out_channels)`

The patchified visual output conditioned on the encoder_hidden_states input. This is the transformer output before the pipeline unpacks it back into video latent dimensions.

required
audio_sample `torch.Tensor` of shape `(batch_size, num_audio_tokens, audio_out_channels)`

The patchified audio output of the audiovisual model before the pipeline unpacks it back into audio latent dimensions.

required

audio_sample instance-attribute

audio_sample: Tensor

sample instance-attribute

sample: Tensor

ColumnParallelApproxGELU

Bases: Module

approximate instance-attribute

approximate = approximate

proj instance-attribute

proj = ColumnParallelLinear(
    dim_in,
    dim_out,
    bias=bias,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.proj" if prefix else "proj",
)

forward

forward(x: Tensor) -> Tensor

LTX2AdaLayerNormSingle

Bases: Module

Adaptive layer norm single (adaLN-single) normalization layer.

As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3) and adapted by the LTX-2.0 model. In particular, the number of modulation parameters to be calculated is now configurable.

Parameters:

Name Type Description Default
embedding_dim `int`

The size of each embedding vector.

required
num_mod_params `int`, *optional*, defaults to `6`

The number of modulation parameters which will be calculated in the first return argument. The default of 6 is standard, but sometimes we may want to have a different (usually smaller) number of modulation parameters.

6
use_additional_conditions `bool`, *optional*, defaults to `False`

Whether to use additional conditions for normalization or not.

False
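The mechanism described above can be sketched as follows: the timestep is embedded, passed through SiLU, and one linear layer emits all `num_mod_params` modulation vectors of size `embedding_dim` at once. This is a simplified, hypothetical stand-in rather than the module's code. In particular, the sinusoidal-plus-MLP timestep embedding below replaces `PixArtAlphaCombinedTimestepSizeEmbeddings` (an assumption), and the class name `AdaLNSingleSketch` is illustrative.

```python
import math

import torch
import torch.nn as nn


class AdaLNSingleSketch(nn.Module):
    """Simplified adaLN-single (hypothetical): timestep -> embedding -> SiLU -> Linear."""

    def __init__(self, embedding_dim: int, num_mod_params: int = 6, freq_dim: int = 256):
        super().__init__()
        self.freq_dim = freq_dim
        self.num_mod_params = num_mod_params
        # Stand-in for PixArtAlphaCombinedTimestepSizeEmbeddings (assumption).
        self.timestep_mlp = nn.Sequential(
            nn.Linear(freq_dim, embedding_dim), nn.SiLU(), nn.Linear(embedding_dim, embedding_dim)
        )
        self.silu = nn.SiLU()
        # A single projection emits all num_mod_params modulation vectors at once.
        self.linear = nn.Linear(embedding_dim, num_mod_params * embedding_dim, bias=True)

    def sinusoidal(self, timestep: torch.Tensor) -> torch.Tensor:
        half = self.freq_dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
        args = timestep.float()[:, None] * freqs[None, :]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    def forward(self, timestep: torch.Tensor):
        embedded = self.timestep_mlp(self.sinusoidal(timestep))
        return self.linear(self.silu(embedded)), embedded


ada = AdaLNSingleSketch(embedding_dim=64, num_mod_params=6)
mod, emb = ada(torch.tensor([10.0, 500.0]))
# Split the flat projection into 6 vectors (e.g. shift/scale/gate for two sub-layers).
chunks = mod.view(2, 6, 64).unbind(dim=1)
print(mod.shape, emb.shape, len(chunks))
```

Making `num_mod_params` a constructor argument is what lets the same layer serve both the 6-parameter case and the smaller 1-, 2- or 4-parameter variants used elsewhere in this model.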

emb instance-attribute

emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
    embedding_dim,
    size_emb_dim=embedding_dim // 3,
    use_additional_conditions=use_additional_conditions,
)

linear instance-attribute

linear = Linear(
    embedding_dim, num_mod_params * embedding_dim, bias=True
)

num_mod_params instance-attribute

num_mod_params = num_mod_params

silu instance-attribute

silu = SiLU()

forward

forward(
    timestep: Tensor,
    added_cond_kwargs: dict[str, Tensor] | None = None,
    batch_size: int | None = None,
    hidden_dtype: dtype | None = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]

LTX2Attention

Bases: Module

Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention.

attn instance-attribute

attn = Attention(
    num_heads=query_num_heads,
    head_size=dim_head,
    num_kv_heads=kv_num_heads,
    softmax_scale=1.0 / dim_head**0.5,
    causal=False,
    prefix=prefix,
    disable_kv_quant=disable_kv_quant,
)

cross_attention_dim instance-attribute

cross_attention_dim = (
    cross_attention_dim
    if cross_attention_dim is not None
    else query_dim
)

dropout instance-attribute

dropout = dropout

head_dim instance-attribute

head_dim = dim_head

heads instance-attribute

heads = query_num_heads

inner_dim instance-attribute

inner_dim = dim_head * heads

inner_kv_dim instance-attribute

inner_kv_dim = dim_head * kv_heads

kv_num_heads instance-attribute

kv_num_heads = num_kv_heads

norm_k instance-attribute

norm_k = TensorParallelRMSNorm(
    dim_head * kv_num_heads,
    eps=norm_eps,
    elementwise_affine=norm_elementwise_affine,
    tp_size=tp_size,
)

norm_q instance-attribute

norm_q = TensorParallelRMSNorm(
    dim_head * query_num_heads,
    eps=norm_eps,
    elementwise_affine=norm_elementwise_affine,
    tp_size=tp_size,
)

out_dim instance-attribute

out_dim = query_dim

query_dim instance-attribute

query_dim = query_dim

query_num_heads instance-attribute

query_num_heads = num_heads

rope_type instance-attribute

rope_type = rope_type

to_gate_logits instance-attribute

to_gate_logits = Linear(query_dim, heads, bias=True)

to_k instance-attribute

to_k = None

to_out instance-attribute

to_out = ModuleList(
    [
        RowParallelLinear(
            inner_dim,
            out_dim,
            bias=out_bias,
            input_is_parallel=True,
            return_bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.to_out.0"
            if prefix
            else "to_out.0",
        ),
        Dropout(dropout) if dropout > 0 else Identity(),
    ]
)

to_q instance-attribute

to_q = None

to_qkv instance-attribute

to_qkv = None

to_v instance-attribute

to_v = None

total_num_heads instance-attribute

total_num_heads = heads

total_num_kv_heads instance-attribute

total_num_kv_heads = kv_heads

use_bias instance-attribute

use_bias = bias

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor | None = None,
    attention_mask: Tensor | None = None,
    query_rotary_emb: tuple[Tensor, Tensor] | None = None,
    key_rotary_emb: tuple[Tensor, Tensor] | None = None,
    **kwargs,
) -> Tensor

get_processor

get_processor() -> Any

prepare_attention_mask

prepare_attention_mask(
    attention_mask: Tensor | None,
    target_length: int,
    batch_size: int,
    out_dim: int = 3,
) -> Tensor | None

set_processor

set_processor(processor: Any) -> None

LTX2AudioVideoAttnProcessor

Attention processor for the LTX-2.0 model. By default it uses scaled dot-product attention (SDPA) when running on PyTorch 2.0. Compared to LTX-1.0, the RoPE embeddings for the queries and keys can be specified separately, which is what enables audio-to-video (a2v) and video-to-audio (v2a) cross-attention.
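Why separate query and key RoPE matters can be sketched in a few lines: in a2v cross-attention the queries are video tokens and the keys are audio tokens, so each side needs rotary frequencies computed for its own sequence. The sketch below assumes the generic split-half ("rotate_half") RoPE convention and plain integer positions for illustration; the real processor selects its convention via `rope_type` and derives positions from coordinate grids. All function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split-half RoPE convention: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def rope_cos_sin(positions: torch.Tensor, head_dim: int, theta: float = 10000.0):
    # positions: (seq_len,) -> cos, sin of shape (seq_len, head_dim).
    inv_freq = 1.0 / theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    angles = positions.float()[:, None] * inv_freq[None, :]
    angles = torch.cat([angles, angles], dim=-1)
    return angles.cos(), angles.sin()


def cross_attn_separate_rope(q, k, v, q_freqs, k_freqs):
    """q: (B, H, Lq, D); k, v: (B, H, Lk, D).

    q_freqs and k_freqs are independent (cos, sin) pairs, one per modality,
    so Lq and Lk may differ (e.g. video queries attending to audio keys).
    """
    q = q * q_freqs[0] + rotate_half(q) * q_freqs[1]
    k = k * k_freqs[0] + rotate_half(k) * k_freqs[1]
    return F.scaled_dot_product_attention(q, k, v)


B, H, D = 1, 2, 8
video_q = torch.randn(B, H, 16, D)   # 16 video tokens
audio_k = torch.randn(B, H, 25, D)   # 25 audio tokens
audio_v = torch.randn(B, H, 25, D)
out = cross_attn_separate_rope(
    video_q, audio_k, audio_v,
    q_freqs=rope_cos_sin(torch.arange(16), D),
    k_freqs=rope_cos_sin(torch.arange(25), D),
)
print(out.shape)  # torch.Size([1, 2, 16, 8])
```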

LTX2AudioVideoRotaryPosEmbed

Bases: Module

Video and audio rotary positional embeddings (RoPE) for the LTX-2.0 model.

Parameters:

Name Type Description Default
causal_offset `int`, *optional*, defaults to `1`

Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where the VAE treats the very first frame differently), but could also be 0 (for non-causal modeling).

1

audio_latents_per_second instance-attribute

audio_latents_per_second = (
    float(sampling_rate)
    / float(hop_length)
    / float(scale_factors[0])
)
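The `audio_latents_per_second` expression above chains two rate conversions: waveform samples per second divided by `hop_length` gives mel frames per second, and dividing by the temporal scale factor gives latent frames per second. The numbers below (16 kHz, hop 160, factor 4) are illustrative assumptions, not necessarily the model's configuration, and the helper name is hypothetical.

```python
def audio_latents_per_second(sampling_rate: int, hop_length: int, scale_factor: int) -> float:
    mel_frames_per_second = float(sampling_rate) / float(hop_length)  # one mel frame per hop
    return mel_frames_per_second / float(scale_factor)                # VAE temporal downsampling


# Illustrative values only (not necessarily the model's configuration):
lps = audio_latents_per_second(16000, 160, 4)
print(lps)                # 25.0
print(round(4.0 * lps))   # 100 latent audio frames for 4 seconds of audio
```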

base_height instance-attribute

base_height = base_height

base_num_frames instance-attribute

base_num_frames = base_num_frames

base_width instance-attribute

base_width = base_width

causal_offset instance-attribute

causal_offset = causal_offset

dim instance-attribute

dim = dim

double_precision instance-attribute

double_precision = double_precision

hop_length instance-attribute

hop_length = hop_length

modality instance-attribute

modality = modality

num_attention_heads instance-attribute

num_attention_heads = num_attention_heads

patch_size instance-attribute

patch_size = patch_size

patch_size_t instance-attribute

patch_size_t = patch_size_t

rope_type instance-attribute

rope_type = rope_type

sampling_rate instance-attribute

sampling_rate = sampling_rate

scale_factors instance-attribute

scale_factors = scale_factors

theta instance-attribute

theta = theta

forward

forward(
    coords: Tensor, device: str | device | None = None
) -> tuple[Tensor, Tensor]

prepare_audio_coords

prepare_audio_coords(
    batch_size: int,
    num_frames: int,
    device: device,
    shift: int = 0,
) -> Tensor

Create per-dimension bounds [inclusive start, exclusive end) of the start and end timestamps for each latent frame. The result has shape (batch_size, 1, num_patches, 2), where:

- axis 1 (size 1) represents the temporal dimension
- axis 3 (size 2) stores the [start, end) bounds within that dimension

Parameters:

Name Type Description Default
batch_size `int`

Batch size of the audio latents.

required
num_frames `int`

Number of latent frames in the audio latents.

required
device `torch.device`

Device on which to create the audio grid.

required
shift `int`, *optional*, defaults to `0`

Offset on the latent indices. Different shift values correspond to different overlapping windows with respect to the same underlying latent grid.

0

Returns:

Type Description
Tensor

torch.Tensor: Per-dimension patch boundaries tensor of shape [batch_size, 1, num_patches, 2].
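One plausible way to produce such per-frame [start, end) timestamp bounds is sketched below: map each latent frame index to mel-frame indices with the temporal scale factor, apply the causal offset so the first latent frame covers fewer mel frames, then convert mel frames to seconds. The causal-offset handling, the default values, and the function name `audio_patch_bounds_sketch` are assumptions for illustration, not the method's actual code.

```python
import torch


def audio_patch_bounds_sketch(
    batch_size: int,
    num_frames: int,
    *,
    sampling_rate: int = 16000,
    hop_length: int = 160,
    scale_factor: int = 4,
    patch_size_t: int = 1,
    causal_offset: int = 1,
    shift: int = 0,
) -> torch.Tensor:
    # Latent frame indices, offset by `shift`.
    grid = torch.arange(shift, num_frames + shift, patch_size_t, dtype=torch.float32)
    # Latent index -> mel-frame index; the causal offset shortens the first frame (assumption).
    start_mel = (grid * scale_factor + causal_offset - scale_factor).clamp(min=0)
    end_mel = ((grid + patch_size_t) * scale_factor + causal_offset - scale_factor).clamp(min=0)
    # Mel frames -> seconds.
    start_s = start_mel * hop_length / sampling_rate
    end_s = end_mel * hop_length / sampling_rate
    coords = torch.stack([start_s, end_s], dim=-1)  # (num_patches, 2)
    return coords.unsqueeze(0).expand(batch_size, -1, -1).unsqueeze(1)  # (B, 1, num_patches, 2)


coords = audio_patch_bounds_sketch(batch_size=2, num_frames=3)
print(coords.shape)  # torch.Size([2, 1, 3, 2])
```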

prepare_coords

prepare_coords(*args, **kwargs)

prepare_video_coords

prepare_video_coords(
    batch_size: int,
    num_frames: int,
    height: int,
    width: int,
    device: device,
    fps: float = 24.0,
) -> Tensor

Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original pixel-space video grid (num_frames, height, width). The result has shape (batch_size, 3, num_patches, 2), where:

- axis 1 (size 3) enumerates the (frame, height, width) dimensions (e.g. index 0 corresponds to frames)
- axis 3 (size 2) stores the [start, end) indices within each dimension

Parameters:

Name Type Description Default
batch_size `int`

Batch size of the video latents.

required
num_frames `int`

Number of latent frames in the video latents.

required
height `int`

Latent height of the video latents.

required
width `int`

Latent width of the video latents.

required
device `torch.device`

Device on which to create the video grid.

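required
fps `float`, *optional*, defaults to `24.0`

The frames per second of the video, used to scale the temporal coordinates.

24.0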

Returns:

Type Description
Tensor

torch.Tensor: Per-dimension patch boundaries tensor of shape [batch_size, 3, num_patches, 2].
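The construction described above can be sketched as: build the latent patch grid over (frame, height, width), attach [start, end) bounds, scale them into pixel space with the VAE factors, and then convert the temporal axis to seconds using `fps`. The scale factors `(8, 32, 32)`, the causal-offset treatment of the first frame, and the name `video_patch_bounds_sketch` are illustrative assumptions rather than the method's actual code.

```python
import torch


def video_patch_bounds_sketch(
    batch_size: int,
    num_frames: int,
    height: int,
    width: int,
    *,
    patch_size: int = 1,
    patch_size_t: int = 1,
    scale_factors: tuple[int, int, int] = (8, 32, 32),
    causal_offset: int = 1,
    fps: float = 24.0,
) -> torch.Tensor:
    # Start index of every latent patch along (frame, height, width).
    grid_f = torch.arange(0, num_frames, patch_size_t, dtype=torch.float32)
    grid_h = torch.arange(0, height, patch_size, dtype=torch.float32)
    grid_w = torch.arange(0, width, patch_size, dtype=torch.float32)
    starts = torch.stack(torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij"), dim=0)  # (3, F, H, W)
    sizes = torch.tensor([patch_size_t, patch_size, patch_size], dtype=torch.float32)
    ends = starts + sizes.view(3, 1, 1, 1)
    latent = torch.stack([starts, ends], dim=-1).flatten(1, 3)  # (3, num_patches, 2)
    # Latent grid -> pixel grid via the VAE downsampling factors.
    pixel = latent * torch.tensor(scale_factors, dtype=torch.float32).view(3, 1, 1)
    # Temporal axis (assumption): causal shift so the first latent frame maps to one pixel frame,
    # then frame indices -> seconds.
    t = (pixel[0] + causal_offset - scale_factors[0]).clamp(min=0) / fps
    coords = torch.cat([t.unsqueeze(0), pixel[1:]], dim=0)
    return coords.unsqueeze(0).expand(batch_size, -1, -1, -1)  # (B, 3, num_patches, 2)


coords = video_patch_bounds_sketch(1, 2, 2, 2)
print(coords.shape)  # torch.Size([1, 3, 8, 2])
```

Patches are flattened frame-major (the `"ij"` meshgrid order), so patch index 4 in this example is the first spatial patch of the second latent frame.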

LTX2FeedForward

Bases: Module

net instance-attribute

net = ModuleList(layers)

forward

forward(hidden_states: Tensor) -> Tensor

LTX2VideoTransformer3DModel

Bases: Module

A Transformer model for video-like data used in LTX.

Parameters:

Name Type Description Default
in_channels `int`, defaults to `128`

The number of channels in the input.

128
out_channels `int`, defaults to `128`

The number of channels in the output.

128
patch_size `int`, defaults to `1`

The size of the spatial patches to use in the patch embedding layer.

1
patch_size_t `int`, defaults to `1`

The size of the temporal patches to use in the patch embedding layer.

1
num_attention_heads `int`, defaults to `32`

The number of heads to use for multi-head attention.

32
attention_head_dim `int`, defaults to `128`

The number of channels in each head.

128
cross_attention_dim `int`, defaults to `4096`

The number of channels in the key/value input (encoder hidden states) of the text cross-attention layers.

4096
num_layers `int`, defaults to `48`

The number of layers of Transformer blocks to use.

48
activation_fn `str`, defaults to `"gelu-approximate"`

Activation function to use in feed-forward.

'gelu-approximate'
qk_norm `str`, defaults to `"rms_norm_across_heads"`

The normalization layer to use.

'rms_norm_across_heads'

audio_caption_projection instance-attribute

audio_caption_projection = PixArtAlphaTextProjection(
    in_features=caption_channels,
    hidden_size=audio_inner_dim,
)

audio_norm_out instance-attribute

audio_norm_out = LayerNorm(
    audio_inner_dim, eps=1e-06, elementwise_affine=False
)

audio_proj_in instance-attribute

audio_proj_in = Linear(audio_in_channels, audio_inner_dim)

audio_proj_out instance-attribute

audio_proj_out = Linear(audio_inner_dim, audio_out_channels)

audio_prompt_adaln instance-attribute

audio_prompt_adaln = LTX2AdaLayerNormSingle(
    audio_inner_dim,
    num_mod_params=2,
    use_additional_conditions=False,
)

audio_rope instance-attribute

audio_rope = LTX2AudioVideoRotaryPosEmbed(
    dim=audio_inner_dim,
    patch_size=audio_patch_size,
    patch_size_t=audio_patch_size_t,
    base_num_frames=audio_pos_embed_max_pos,
    sampling_rate=audio_sampling_rate,
    hop_length=audio_hop_length,
    scale_factors=[audio_scale_factor],
    theta=rope_theta,
    causal_offset=causal_offset,
    modality="audio",
    double_precision=rope_double_precision,
    rope_type=rope_type,
    num_attention_heads=audio_num_attention_heads,
)

audio_scale_shift_table instance-attribute

audio_scale_shift_table = Parameter(
    randn(2, audio_inner_dim) / audio_inner_dim**0.5
)

audio_time_embed instance-attribute

audio_time_embed = LTX2AdaLayerNormSingle(
    audio_inner_dim,
    num_mod_params=audio_num_mod_params,
    use_additional_conditions=False,
)

av_cross_attn_audio_scale_shift instance-attribute

av_cross_attn_audio_scale_shift = LTX2AdaLayerNormSingle(
    audio_inner_dim,
    num_mod_params=4,
    use_additional_conditions=False,
)

av_cross_attn_audio_v2a_gate instance-attribute

av_cross_attn_audio_v2a_gate = LTX2AdaLayerNormSingle(
    audio_inner_dim,
    num_mod_params=1,
    use_additional_conditions=False,
)

av_cross_attn_video_a2v_gate instance-attribute

av_cross_attn_video_a2v_gate = LTX2AdaLayerNormSingle(
    inner_dim,
    num_mod_params=1,
    use_additional_conditions=False,
)

av_cross_attn_video_scale_shift instance-attribute

av_cross_attn_video_scale_shift = LTX2AdaLayerNormSingle(
    inner_dim,
    num_mod_params=4,
    use_additional_conditions=False,
)

caption_projection instance-attribute

caption_projection = PixArtAlphaTextProjection(
    in_features=caption_channels, hidden_size=inner_dim
)

config instance-attribute

config = SimpleNamespace(
    in_channels=in_channels,
    out_channels=out_channels,
    patch_size=patch_size,
    patch_size_t=patch_size_t,
    num_attention_heads=num_attention_heads,
    attention_head_dim=attention_head_dim,
    cross_attention_dim=cross_attention_dim,
    vae_scale_factors=vae_scale_factors,
    pos_embed_max_pos=pos_embed_max_pos,
    base_height=base_height,
    base_width=base_width,
    audio_in_channels=audio_in_channels,
    audio_out_channels=audio_out_channels,
    audio_patch_size=audio_patch_size,
    audio_patch_size_t=audio_patch_size_t,
    audio_num_attention_heads=audio_num_attention_heads,
    audio_attention_head_dim=audio_attention_head_dim,
    audio_cross_attention_dim=audio_cross_attention_dim,
    audio_scale_factor=audio_scale_factor,
    audio_pos_embed_max_pos=audio_pos_embed_max_pos,
    audio_sampling_rate=audio_sampling_rate,
    audio_hop_length=audio_hop_length,
    num_layers=num_layers,
    activation_fn=activation_fn,
    qk_norm=qk_norm,
    norm_elementwise_affine=norm_elementwise_affine,
    norm_eps=norm_eps,
    caption_channels=caption_channels,
    attention_bias=attention_bias,
    attention_out_bias=attention_out_bias,
    rope_theta=rope_theta,
    rope_double_precision=rope_double_precision,
    causal_offset=causal_offset,
    timestep_scale_multiplier=timestep_scale_multiplier,
    cross_attn_timestep_scale_multiplier=cross_attn_timestep_scale_multiplier,
    rope_type=rope_type,
)

cross_attn_audio_rope instance-attribute

cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed(
    dim=audio_cross_attention_dim,
    patch_size=audio_patch_size,
    patch_size_t=audio_patch_size_t,
    base_num_frames=cross_attn_pos_embed_max_pos,
    sampling_rate=audio_sampling_rate,
    hop_length=audio_hop_length,
    theta=rope_theta,
    causal_offset=causal_offset,
    modality="audio",
    double_precision=rope_double_precision,
    rope_type=rope_type,
    num_attention_heads=audio_num_attention_heads,
)

cross_attn_rope instance-attribute

cross_attn_rope = LTX2AudioVideoRotaryPosEmbed(
    dim=audio_cross_attention_dim,
    patch_size=patch_size,
    patch_size_t=patch_size_t,
    base_num_frames=cross_attn_pos_embed_max_pos,
    base_height=base_height,
    base_width=base_width,
    theta=rope_theta,
    causal_offset=causal_offset,
    modality="video",
    double_precision=rope_double_precision,
    rope_type=rope_type,
    num_attention_heads=num_attention_heads,
)

gradient_checkpointing instance-attribute

gradient_checkpointing = False

norm_out instance-attribute

norm_out = LayerNorm(
    inner_dim, eps=1e-06, elementwise_affine=False
)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "to_qkv": ["to_q", "to_k", "to_v"]
}

perturbed_attn instance-attribute

perturbed_attn = perturbed_attn

proj_in instance-attribute

proj_in = Linear(in_channels, inner_dim)

proj_out instance-attribute

proj_out = Linear(inner_dim, out_channels)

prompt_adaln instance-attribute

prompt_adaln = LTX2AdaLayerNormSingle(
    inner_dim,
    num_mod_params=2,
    use_additional_conditions=False,
)

prompt_modulation instance-attribute

prompt_modulation = cross_attn_mod or audio_cross_attn_mod

rope instance-attribute

rope = LTX2AudioVideoRotaryPosEmbed(
    dim=inner_dim,
    patch_size=patch_size,
    patch_size_t=patch_size_t,
    base_num_frames=pos_embed_max_pos,
    base_height=base_height,
    base_width=base_width,
    scale_factors=vae_scale_factors,
    theta=rope_theta,
    causal_offset=causal_offset,
    modality="video",
    double_precision=rope_double_precision,
    rope_type=rope_type,
    num_attention_heads=num_attention_heads,
)

scale_shift_table instance-attribute

scale_shift_table = Parameter(
    randn(2, inner_dim) / inner_dim**0.5
)

time_embed instance-attribute

time_embed = LTX2AdaLayerNormSingle(
    inner_dim,
    num_mod_params=video_num_mod_params,
    use_additional_conditions=False,
)

transformer_blocks instance-attribute

transformer_blocks = ModuleList(
    [
        (
            LTX2VideoTransformerBlock(
                dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                cross_attention_dim=cross_attention_dim,
                audio_dim=audio_inner_dim,
                audio_num_attention_heads=audio_num_attention_heads,
                audio_attention_head_dim=audio_attention_head_dim,
                audio_cross_attention_dim=audio_cross_attention_dim,
                video_gated_attn=gated_attn,
                video_cross_attn_adaln=cross_attn_mod,
                audio_gated_attn=audio_gated_attn,
                audio_cross_attn_adaln=audio_cross_attn_mod,
                qk_norm=qk_norm,
                activation_fn=activation_fn,
                attention_bias=attention_bias,
                attention_out_bias=attention_out_bias,
                eps=norm_eps,
                elementwise_affine=norm_elementwise_affine,
                rope_type=rope_type,
                perturbed_attn=perturbed_attn,
                quant_config=quant_config,
                prefix=f"transformer_blocks.{layer_idx}",
            )
        )
        for layer_idx in (range(num_layers))
    ]
)

disable_gradient_checkpointing

disable_gradient_checkpointing() -> None

enable_gradient_checkpointing

enable_gradient_checkpointing() -> None

forward

forward(
    hidden_states: Tensor,
    audio_hidden_states: Tensor,
    encoder_hidden_states: Tensor,
    audio_encoder_hidden_states: Tensor,
    timestep: LongTensor,
    audio_timestep: LongTensor | None = None,
    sigma: Tensor | None = None,
    audio_sigma: Tensor | None = None,
    encoder_attention_mask: Tensor | None = None,
    audio_encoder_attention_mask: Tensor | None = None,
    num_frames: int | None = None,
    height: int | None = None,
    width: int | None = None,
    fps: float = 24.0,
    audio_num_frames: int | None = None,
    video_coords: Tensor | None = None,
    audio_coords: Tensor | None = None,
    attention_kwargs: dict[str, Any] | None = None,
    return_dict: bool = True,
    **kwargs,
) -> Tensor

Forward pass of the LTX-2.0 audiovisual transformer.

Parameters:

Name Type Description Default
hidden_states `torch.Tensor`

Input patchified video latents of shape (batch_size, num_video_tokens, in_channels).

required
audio_hidden_states `torch.Tensor`

Input patchified audio latents of shape (batch_size, num_audio_tokens, audio_in_channels).

required
encoder_hidden_states `torch.Tensor`

Input video text embeddings of shape (batch_size, text_seq_len, self.config.caption_channels).

required
audio_encoder_hidden_states `torch.Tensor`

Input audio text embeddings of shape (batch_size, text_seq_len, self.config.caption_channels).

required
timestep `torch.Tensor`

Input timestep of shape (batch_size, num_video_tokens). These should already be scaled by self.config.timestep_scale_multiplier.

required
audio_timestep `torch.Tensor`, *optional*

Input timestep of shape (batch_size,) or (batch_size, num_audio_tokens) for audio modulation params. This is only used by certain pipelines such as the I2V pipeline.

None
encoder_attention_mask `torch.Tensor`, *optional*

Optional multiplicative text attention mask of shape (batch_size, text_seq_len).

None
audio_encoder_attention_mask `torch.Tensor`, *optional*

Optional multiplicative text attention mask of shape (batch_size, text_seq_len) for audio modeling.

None
num_frames `int`, *optional*

The number of latent video frames. Used if calculating the video coordinates for RoPE.

None
height `int`, *optional*

The latent video height. Used if calculating the video coordinates for RoPE.

None
width `int`, *optional*

The latent video width. Used if calculating the video coordinates for RoPE.

None
fps `float`, *optional*, defaults to `24.0`

The desired frames per second of the generated video. Used if calculating the video coordinates for RoPE.

24.0
audio_num_frames `int`, *optional*

The number of latent audio frames. Used if calculating the audio coordinates for RoPE.

None
video_coords `torch.Tensor`, *optional*

The video coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape (batch_size, 3, num_video_tokens, 2). If not supplied, this will be calculated inside forward.

None
audio_coords `torch.Tensor`, *optional*

The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape (batch_size, 1, num_audio_tokens, 2). If not supplied, this will be calculated inside forward.

None
attention_kwargs `Dict[str, Any]`, *optional*

Optional dict of keyword args to be passed to the attention processor.

None
return_dict `bool`, *optional*, defaults to `True`

Whether to return a dict-like structured output of type AudioVisualModelOutput or a tuple.

True
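The `encoder_attention_mask` and `audio_encoder_attention_mask` parameters are described as multiplicative masks (1 = attend, 0 = ignore). Attention kernels such as PyTorch SDPA instead accept an additive float bias, so a conversion step is typically needed. The sketch below shows one common convention; the `-10000.0` fill value, the `(B, 1, 1, L)` output shape, and the helper name are assumptions, and the module may do this differently.

```python
import torch
import torch.nn.functional as F


def multiplicative_to_additive_mask(
    mask: torch.Tensor, dtype: torch.dtype = torch.float32, fill: float = -10000.0
) -> torch.Tensor:
    """(B, L) mask with 1 = attend, 0 = ignore  ->  (B, 1, 1, L) additive bias.

    The result broadcasts over heads and query positions in SDPA.
    """
    bias = (1.0 - mask.to(dtype)) * fill
    return bias[:, None, None, :]


mask = torch.tensor([[1, 1, 0]])           # third text token is padding
bias = multiplicative_to_additive_mask(mask)
q = torch.randn(1, 2, 4, 8)                 # (B, heads, query_len, head_dim)
k = torch.randn(1, 2, 3, 8)
v = torch.randn(1, 2, 3, 8)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
```

With a large negative bias the padded key receives effectively zero softmax weight, so the result matches attention over the unpadded keys only.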

Returns:

Type Description
Tensor

AudioVisualModelOutput or tuple: If return_dict is True, returns a structured output of type AudioVisualModelOutput, otherwise a tuple is returned where the first element is the denoised video latent patch sequence and the second element is the denoised audio latent patch sequence.
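Callers that may receive either return form can normalize it with a small helper. The sketch uses a hypothetical stand-in dataclass, `AudioVisualOutputSketch`, that mirrors the `sample` and `audio_sample` fields of AudioVisualModelOutput, plus a hypothetical `unpack_av_output` function; it does not call the real model.

```python
from dataclasses import dataclass

import torch


@dataclass
class AudioVisualOutputSketch:
    # Hypothetical stand-in mirroring AudioVisualModelOutput's fields.
    sample: torch.Tensor        # (batch_size, num_video_tokens, out_channels)
    audio_sample: torch.Tensor  # (batch_size, num_audio_tokens, audio_out_channels)


def unpack_av_output(output):
    """Return (video_patches, audio_patches) for either return_dict=True or return_dict=False."""
    if isinstance(output, tuple):
        return output[0], output[1]
    return output.sample, output.audio_sample


video = torch.zeros(1, 6, 128)
audio = torch.zeros(1, 4, 128)
for out in (AudioVisualOutputSketch(video, audio), (video, audio)):
    v, a = unpack_av_output(out)
    print(v.shape, a.shape)
```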

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights from a pretrained model, mapping separate Q/K/V projections into fused QKV projections for self-attention blocks.

Returns:

Type Description
set[str]

Set of parameter names that were successfully loaded.
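The Q/K/V fusion described above (and reflected in `packed_modules_mapping = {"to_qkv": ["to_q", "to_k", "to_v"]}`) amounts to concatenating the three projection weights along the output dimension in q, k, v order. The sketch below illustrates that idea on a plain dict; `fuse_qkv_weights` is hypothetical, ignores tensor-parallel sharding, and fuses every complete q/k/v set, whereas the source states fusing applies to self-attention blocks.

```python
import torch


def fuse_qkv_weights(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Hypothetical helper: merge `<prefix>.to_q/.to_k/.to_v.<param>` into `<prefix>.to_qkv.<param>`."""
    shards = {"to_q": 0, "to_k": 1, "to_v": 2}
    fused: dict[str, torch.Tensor] = {}
    groups: dict[str, list] = {}
    for name, tensor in state_dict.items():
        parts = name.split(".")
        idx = next((i for i, p in enumerate(parts) if p in shards), None)
        if idx is None:
            fused[name] = tensor  # non-Q/K/V params pass through unchanged
            continue
        target = ".".join(parts[:idx] + ["to_qkv"] + parts[idx + 1:])
        groups.setdefault(target, [None, None, None])[shards[parts[idx]]] = tensor
    for target, (q, k, v) in groups.items():
        if q is None or k is None or v is None:
            raise ValueError(f"incomplete q/k/v set for {target}")
        # Stack along the output dimension, in q, k, v order.
        fused[target] = torch.cat([q, k, v], dim=0)
    return fused


sd = {
    "blocks.0.attn1.to_q.weight": torch.ones(4, 4),
    "blocks.0.attn1.to_k.weight": 2 * torch.ones(4, 4),
    "blocks.0.attn1.to_v.weight": 3 * torch.ones(4, 4),
    "blocks.0.attn1.to_out.0.weight": torch.zeros(4, 4),
}
print(sorted(fuse_qkv_weights(sd)))
```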

LTX2VideoTransformerBlock

Bases: Module

Transformer block used in LTX-2.0.

Parameters:

Name Type Description Default
dim `int`

The number of channels in the input and output.

required
num_attention_heads `int`

The number of heads to use for multi-head attention.

required
attention_head_dim `int`

The number of channels in each head.

required
qk_norm `str`, defaults to `"rms_norm_across_heads"`

The normalization layer to use.

'rms_norm_across_heads'
activation_fn `str`, defaults to `"gelu-approximate"`

Activation function to use in feed-forward.

'gelu-approximate'
eps `float`, defaults to `1e-6`

Epsilon value for normalization layers.

1e-06

attn1 instance-attribute

attn1 = LTX2Attention(
    query_dim=dim,
    heads=num_attention_heads,
    kv_heads=num_attention_heads,
    dim_head=attention_head_dim,
    bias=attention_bias,
    cross_attention_dim=None,
    out_bias=attention_out_bias,
    qk_norm=qk_norm,
    rope_type=rope_type,
    apply_gated_attention=video_gated_attn,
    quant_config=quant_config,
    prefix=f"{prefix}.attn1" if prefix else "attn1",
)

attn2 instance-attribute

attn2 = LTX2Attention(
    query_dim=dim,
    cross_attention_dim=cross_attention_dim,
    heads=num_attention_heads,
    kv_heads=num_attention_heads,
    dim_head=attention_head_dim,
    bias=attention_bias,
    out_bias=attention_out_bias,
    qk_norm=qk_norm,
    rope_type=rope_type,
    apply_gated_attention=video_gated_attn,
    quant_config=quant_config,
    prefix=f"{prefix}.attn2" if prefix else "attn2",
    disable_kv_quant=True,
)

audio_a2v_cross_attn_scale_shift_table instance-attribute

audio_a2v_cross_attn_scale_shift_table = Parameter(
    randn(5, audio_dim)
)

audio_attn1 instance-attribute

audio_attn1 = LTX2Attention(
    query_dim=audio_dim,
    heads=audio_num_attention_heads,
    kv_heads=audio_num_attention_heads,
    dim_head=audio_attention_head_dim,
    bias=attention_bias,
    cross_attention_dim=None,
    out_bias=attention_out_bias,
    qk_norm=qk_norm,
    rope_type=rope_type,
    apply_gated_attention=audio_gated_attn,
    quant_config=quant_config,
    prefix=f"{prefix}.audio_attn1"
    if prefix
    else "audio_attn1",
)

audio_attn2 instance-attribute

audio_attn2 = LTX2Attention(
    query_dim=audio_dim,
    cross_attention_dim=audio_cross_attention_dim,
    heads=audio_num_attention_heads,
    kv_heads=audio_num_attention_heads,
    dim_head=audio_attention_head_dim,
    bias=attention_bias,
    out_bias=attention_out_bias,
    qk_norm=qk_norm,
    rope_type=rope_type,
    apply_gated_attention=audio_gated_attn,
    quant_config=quant_config,
    prefix=f"{prefix}.audio_attn2"
    if prefix
    else "audio_attn2",
    disable_kv_quant=True,
)

audio_cross_attn_adaln instance-attribute

audio_cross_attn_adaln = audio_cross_attn_adaln

audio_ff instance-attribute

audio_ff = LTX2FeedForward(
    audio_dim,
    activation_fn=activation_fn,
    quant_config=quant_config,
    prefix=f"{prefix}.audio_ff" if prefix else "audio_ff",
)

audio_norm1 instance-attribute

audio_norm1 = _make_rms_norm(
    audio_dim,
    eps=eps,
    elementwise_affine=elementwise_affine,
)

audio_norm2 instance-attribute

audio_norm2 = _make_rms_norm(
    audio_dim,
    eps=eps,
    elementwise_affine=elementwise_affine,
)

audio_norm3 instance-attribute

audio_norm3 = _make_rms_norm(
    audio_dim,
    eps=eps,
    elementwise_affine=elementwise_affine,
)

audio_prompt_scale_shift_table instance-attribute

audio_prompt_scale_shift_table = Parameter(
    randn(2, audio_dim)
)

audio_scale_shift_table instance-attribute

audio_scale_shift_table = Parameter(
    randn(audio_mod_param_num, audio_dim) / audio_dim**0.5
)

audio_to_video_attn instance-attribute

audio_to_video_attn = LTX2Attention(
    query_dim=dim,
    cross_attention_dim=audio_dim,
    heads=audio_num_attention_heads,
    kv_heads=audio_num_attention_heads,
    dim_head=audio_attention_head_dim,
    bias=attention_bias,
    out_bias=attention_out_bias,
    qk_norm=qk_norm,
    rope_type=rope_type,
    apply_gated_attention=video_gated_attn,
    quant_config=quant_config,
    prefix=f"{prefix}.audio_to_video_attn"
    if prefix
    else "audio_to_video_attn",
)

audio_to_video_norm instance-attribute

audio_to_video_norm = _make_rms_norm(
    dim, eps=eps, elementwise_affine=elementwise_affine
)

cross_attn_adaln instance-attribute

cross_attn_adaln = (
    video_cross_attn_adaln or audio_cross_attn_adaln
)

ff instance-attribute

ff = LTX2FeedForward(
    dim,
    activation_fn=activation_fn,
    quant_config=quant_config,
    prefix=f"{prefix}.ff" if prefix else "ff",
)

norm1 instance-attribute

norm1 = _make_rms_norm(
    dim, eps=eps, elementwise_affine=elementwise_affine
)

norm2 instance-attribute

norm2 = _make_rms_norm(
    dim, eps=eps, elementwise_affine=elementwise_affine
)

norm3 instance-attribute

norm3 = _make_rms_norm(
    dim, eps=eps, elementwise_affine=elementwise_affine
)

perturbed_attn instance-attribute

perturbed_attn = perturbed_attn

prompt_scale_shift_table instance-attribute

prompt_scale_shift_table = Parameter(randn(2, dim))

scale_shift_table instance-attribute

scale_shift_table = Parameter(
    randn(video_mod_param_num, dim) / dim**0.5
)

video_a2v_cross_attn_scale_shift_table instance-attribute

video_a2v_cross_attn_scale_shift_table = Parameter(
    randn(5, dim)
)

video_cross_attn_adaln instance-attribute

video_cross_attn_adaln = video_cross_attn_adaln

video_to_audio_attn instance-attribute

video_to_audio_attn = LTX2Attention(
    query_dim=audio_dim,
    cross_attention_dim=dim,
    heads=audio_num_attention_heads,
    kv_heads=audio_num_attention_heads,
    dim_head=audio_attention_head_dim,
    bias=attention_bias,
    out_bias=attention_out_bias,
    qk_norm=qk_norm,
    rope_type=rope_type,
    apply_gated_attention=audio_gated_attn,
    quant_config=quant_config,
    prefix=f"{prefix}.video_to_audio_attn"
    if prefix
    else "video_to_audio_attn",
)

video_to_audio_norm instance-attribute

video_to_audio_norm = _make_rms_norm(
    audio_dim,
    eps=eps,
    elementwise_affine=elementwise_affine,
)

forward

forward(
    hidden_states: Tensor,
    audio_hidden_states: Tensor,
    encoder_hidden_states: Tensor,
    audio_encoder_hidden_states: Tensor,
    temb: Tensor,
    temb_audio: Tensor,
    temb_ca_scale_shift: Tensor,
    temb_ca_audio_scale_shift: Tensor,
    temb_ca_gate: Tensor,
    temb_ca_audio_gate: Tensor,
    temb_prompt: Tensor | None = None,
    temb_prompt_audio: Tensor | None = None,
    video_rotary_emb: tuple[Tensor, Tensor] | None = None,
    audio_rotary_emb: tuple[Tensor, Tensor] | None = None,
    ca_video_rotary_emb: tuple[Tensor, Tensor]
    | None = None,
    ca_audio_rotary_emb: tuple[Tensor, Tensor]
    | None = None,
    encoder_attention_mask: Tensor | None = None,
    audio_encoder_attention_mask: Tensor | None = None,
    self_attention_mask: Tensor | None = None,
    audio_self_attention_mask: Tensor | None = None,
    a2v_cross_attention_mask: Tensor | None = None,
    v2a_cross_attention_mask: Tensor | None = None,
    use_a2v_cross_attention: bool = True,
    use_v2a_cross_attention: bool = True,
    perturbation_mask: Tensor | None = None,
    all_perturbed: bool | None = None,
) -> Tensor

get_mod_params staticmethod

get_mod_params(
    scale_shift_table: Tensor, temb: Tensor, batch_size: int
) -> tuple[Tensor, ...]
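`get_mod_params` has no docstring here. Assuming `temb` carries `(batch, tokens, N * D)` values from an adaLN-single projection and `scale_shift_table` is an `(N, D)` learned table like the block's `scale_shift_table` attribute, a plausible reading is: add the table to the reshaped timestep projection and split the result into N modulation tensors. The sketch and its name are hypothetical, as is the `x * (1 + scale) + shift` modulation shown in the usage lines.

```python
import torch


def get_mod_params_sketch(scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int):
    """scale_shift_table: (N, D) learned offsets; temb: (B, T, N * D) from an adaLN-single projection.

    Returns a tuple of N tensors, each of shape (B, T, D).
    """
    num_params = scale_shift_table.shape[0]
    values = scale_shift_table[None, None].to(temb) + temb.reshape(
        batch_size, temb.shape[1], num_params, -1
    )
    return values.unbind(dim=2)


B, T, D = 2, 5, 8
table = torch.randn(6, D)
temb = torch.randn(B, T, 6 * D)
shift, scale, gate, *_ = get_mod_params_sketch(table, temb, B)
x = torch.randn(B, T, D)
# A typical adaLN modulation step (assumed convention): normalize, scale, shift, then gate.
h = torch.nn.functional.layer_norm(x, (D,)) * (1 + scale) + shift
x = x + gate * h  # a real block would apply attention or FF to h before gating
```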

TensorParallelRMSNorm

Bases: Module

RMSNorm that computes statistics across TP shards for the q/k norm.

LTX2 uses qk_norm="rms_norm_across_heads" while Q/K are tensor-parallel sharded. A local RMSNorm would compute statistics on only the local shard, which changes the normalization when TP > 1. We all-reduce the squared sum to match the global RMS across all heads.
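The reasoning above can be checked with a single-process simulation: split the hidden dimension into `tp_size` contiguous shards, have each "rank" compute its local sum of squares, sum those (standing in for the all-reduce), and divide by the global hidden size. The functions below are hypothetical illustrations; the real layer performs an actual tensor-parallel all-reduce.

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


def tp_rms_norm_simulated(x: torch.Tensor, weight: torch.Tensor, tp_size: int, eps: float = 1e-6):
    """Single-process simulation: each 'rank' owns one contiguous slice of the hidden dim."""
    x_shards = x.chunk(tp_size, dim=-1)
    w_shards = weight.chunk(tp_size, dim=-1)
    # Each rank computes a local sum of squares; summing them stands in for the all-reduce.
    global_sq_sum = sum(s.pow(2).sum(-1, keepdim=True) for s in x_shards)
    global_hidden_size = sum(s.shape[-1] for s in x_shards)  # local hidden_size * tp_size
    inv_rms = torch.rsqrt(global_sq_sum / global_hidden_size + eps)
    # Each rank normalizes only its own shard, using the global statistic.
    return torch.cat([s * inv_rms * w for s, w in zip(x_shards, w_shards)], dim=-1)


def shard_local_rms_norm(x, weight, tp_size, eps=1e-6):
    # What a naive per-shard RMSNorm would compute (incorrect when tp_size > 1).
    return torch.cat(
        [rms_norm(s, w, eps) for s, w in zip(x.chunk(tp_size, -1), weight.chunk(tp_size, -1))], dim=-1
    )


torch.manual_seed(0)
x = torch.randn(2, 5, 16)
w = torch.randn(16)
print(torch.allclose(tp_rms_norm_simulated(x, w, tp_size=4), rms_norm(x, w), atol=1e-5))
```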

eps instance-attribute

eps = eps

global_hidden_size instance-attribute

global_hidden_size = hidden_size * max(tp_size, 1)

hidden_size instance-attribute

hidden_size = hidden_size

tp_size instance-attribute

tp_size = tp_size

weight instance-attribute

weight = Parameter(ones(hidden_size))

forward

forward(x: Tensor) -> Tensor

apply_interleaved_rotary_emb

apply_interleaved_rotary_emb(
    x: Tensor, freqs: tuple[Tensor, Tensor]
) -> Tensor

apply_split_rotary_emb

apply_split_rotary_emb(
    x: Tensor,
    freqs: tuple[Tensor, Tensor],
    *,
    head_dim: int,
) -> Tensor
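The two helper names suggest the two common RoPE channel layouts. The sketch below illustrates them under that assumption: "interleaved" rotates adjacent channel pairs `(x0, x1), (x2, x3), ...`, while "split" rotates `(x[i], x[i + D/2])`. These are generic illustrations with hypothetical names, not the module's code; in particular, the keyword-only `head_dim` argument of `apply_split_rotary_emb` (presumably used for multi-head reshaping) is not modeled.

```python
import math

import torch


def interleaved_rope_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate adjacent pairs; cos/sin have the same last dim as x, each angle repeated twice."""
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    x_even, x_odd = pairs.unbind(-1)
    x_rot = torch.stack([-x_odd, x_even], dim=-1).flatten(-2)
    return x * cos + x_rot * sin


def split_rope_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate (x[i], x[i + D/2]) pairs; cos/sin have last dim D/2."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)


x = torch.tensor([1.0, 0.0, 0.0, 1.0])
angles = torch.tensor([math.pi / 2, 0.0])  # rotate the first pair by 90 degrees, leave the second
out_interleaved = interleaved_rope_sketch(
    x, angles.repeat_interleave(2).cos(), angles.repeat_interleave(2).sin()
)
out_split = split_rope_sketch(x, angles.cos(), angles.sin())
```

The same input produces different outputs under the two layouts because the channel pairing differs, which is why the rotary embedding and the attention layers must agree on `rope_type`.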