vllm_omni.diffusion.models.ltx2.ltx2_transformer
AudioVisualModelOutput dataclass
Bases: BaseOutput
Holds the output of an audiovisual model which produces both visual (e.g. video) and audio outputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
sample | `torch.Tensor` of shape `(batch_size, num_video_tokens, out_channels)` | The patchified visual output conditioned on the | required |
audio_sample | `torch.Tensor` of shape `(batch_size, num_audio_tokens, audio_out_channels)` | The patchified audio output of the audiovisual model before the pipeline unpacks it back into audio latent dimensions. | required |
ColumnParallelApproxGELU
Bases: Module
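`ColumnParallelApproxGELU` carries no docstring here. Going by its name (an assumption), it pairs a column-parallel linear projection with the tanh-approximate GELU selected by `activation_fn="gelu-approximate"`. The pure-Python sketch below shows why that pairing needs no communication: GELU is elementwise, so each rank can activate its own slice of output features independently. The helpers `gelu_tanh`, `linear`, and `column_parallel_gelu` are hypothetical, not vLLM-Omni APIs.

```python
import math

# Assumption: ColumnParallelApproxGELU = column-parallel linear + tanh GELU
# (inferred from the class name only).


def gelu_tanh(x: float) -> float:
    # Tanh approximation of GELU, as in PyTorch's GELU(approximate="tanh").
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


def linear(x, weight_rows):
    # Each weight row produces one output feature (bias omitted).
    return [sum(w * v for w, v in zip(row, x)) for row in weight_rows]


def column_parallel_gelu(x, weight_rows, tp_size):
    # Split the output features (weight rows) evenly across tp_size ranks.
    per_rank = len(weight_rows) // tp_size
    shards = [weight_rows[r * per_rank:(r + 1) * per_rank] for r in range(tp_size)]
    # Each rank activates its local slice; no communication is needed
    # because GELU acts elementwise.
    local_outputs = [[gelu_tanh(y) for y in linear(x, shard)] for shard in shards]
    # Concatenating the shards (an all-gather) recovers the full activation.
    return [y for out in local_outputs for y in out]


x = [0.5, -1.0, 2.0]
W = [[0.1, 0.2, 0.3], [-0.4, 0.5, 0.6], [0.7, -0.8, 0.9], [1.0, 1.1, -1.2]]
full = [gelu_tanh(y) for y in linear(x, W)]
sharded = column_parallel_gelu(x, W, tp_size=2)
print(full == sharded)  # True
```

The shards are concatenated here only to compare against the unsharded result. In a Megatron-style layout the sharded activation would typically feed a row-parallel projection directly, the way `LTX2Attention.to_out` uses `RowParallelLinear(..., input_is_parallel=True)`.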
LTX2AdaLayerNormSingle
Bases: Module
Adaptive layer norm single (adaLN-single) normalization layer.
As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3) and adapted for the LTX-2.0 model, with one change: the number of modulation parameters to compute is configurable.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
embedding_dim | `int` | The size of each embedding vector. | required |
num_mod_params | `int`, *optional*, defaults to `6` | The number of modulation parameters which will be calculated in the first return argument. The default of 6 is standard, but sometimes we may want to have a different (usually smaller) number of modulation parameters. | 6 |
use_additional_conditions | `bool`, *optional*, defaults to `False` | Whether to use additional conditions for normalization or not. | False |
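A minimal pure-Python sketch of the adaLN-single projection, assuming the PixArt-Alpha formulation: the (already embedded) timestep vector passes through SiLU and a single linear layer whose output is split into `num_mod_params` chunks of size `embedding_dim`. The sinusoidal timestep embedding and any additional conditions are omitted, and `adaln_single` is a hypothetical helper, not the module's API.

```python
import math
import random


def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))


def adaln_single(temb, weight, bias, num_mod_params=6):
    """Project a timestep embedding into num_mod_params modulation vectors.

    temb:   list of length embedding_dim (timestep embedding, assumed precomputed)
    weight: (num_mod_params * embedding_dim) rows of length embedding_dim
    bias:   list of length num_mod_params * embedding_dim
    """
    dim = len(temb)
    activated = [silu(v) for v in temb]
    projected = [b + sum(w * a for w, a in zip(row, activated))
                 for row, b in zip(weight, bias)]
    # Split the flat projection into num_mod_params chunks of size dim.
    return [projected[i * dim:(i + 1) * dim] for i in range(num_mod_params)]


random.seed(0)
dim, n = 4, 6
temb = [random.uniform(-1, 1) for _ in range(dim)]
weight = [[random.uniform(-0.1, 0.1) for _ in range(dim)] for _ in range(n * dim)]
bias = [0.0] * (n * dim)

chunks = adaln_single(temb, weight, bias, num_mod_params=n)
# PixArt-Alpha reads six chunks as shift/scale/gate for attention and for the MLP.
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunks
print(len(chunks), len(chunks[0]))  # 6 4
```

Instances built with `num_mod_params=2`, such as `prompt_adaln`, would yield just two chunks from the same projection.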
LTX2Attention
Bases: Module
Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention.
attn instance-attribute
attn = Attention(
num_heads=query_num_heads,
head_size=dim_head,
num_kv_heads=kv_num_heads,
softmax_scale=1.0 / dim_head**0.5,
causal=False,
prefix=prefix,
disable_kv_quant=disable_kv_quant,
)
cross_attention_dim instance-attribute
norm_k instance-attribute
norm_k = TensorParallelRMSNorm(
dim_head * kv_num_heads,
eps=norm_eps,
elementwise_affine=norm_elementwise_affine,
tp_size=tp_size,
)
norm_q instance-attribute
norm_q = TensorParallelRMSNorm(
dim_head * query_num_heads,
eps=norm_eps,
elementwise_affine=norm_elementwise_affine,
tp_size=tp_size,
)
to_out instance-attribute
to_out = ModuleList(
[
RowParallelLinear(
inner_dim,
out_dim,
bias=out_bias,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.to_out.0"
if prefix
else "to_out.0",
),
Dropout(dropout) if dropout > 0 else Identity(),
]
)
forward
forward(
hidden_states: Tensor,
encoder_hidden_states: Tensor | None = None,
attention_mask: Tensor | None = None,
query_rotary_emb: tuple[Tensor, Tensor] | None = None,
key_rotary_emb: tuple[Tensor, Tensor] | None = None,
**kwargs,
) -> Tensor
LTX2AudioVideoAttnProcessor
Attention processor for the LTX-2.0 model (SDPA is used by default with PyTorch 2.0). Unlike in LTX-1.0, the RoPE embeddings for the queries and keys can be specified separately, which is what enables audio-to-video (a2v) and video-to-audio (v2a) cross-attention.
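Separate query and key RoPE can be illustrated with a small pure-Python sketch. Each side builds its own (cos, sin) tables from its own positions, for example video timestamps for the queries and audio timestamps for the keys in a2v attention, and the resulting attention score depends only on the offset between the two positions. The interleaved-pair rotation and the `theta=10000.0` base are assumptions; the real module's `rope_type` option may select a different layout, and all helper names here are hypothetical.

```python
import math


def rope_tables(positions, head_dim, theta=10000.0):
    # One (cos, sin) value per position and per rotation frequency.
    freqs = [theta ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
    cos = [[math.cos(p * f) for f in freqs] for p in positions]
    sin = [[math.sin(p * f) for f in freqs] for p in positions]
    return cos, sin


def apply_rope(vec, cos_row, sin_row):
    # Rotate consecutive (even, odd) feature pairs by the position angle
    # (interleaved layout assumed).
    out = []
    for i, (c, s) in enumerate(zip(cos_row, sin_row)):
        x0, x1 = vec[2 * i], vec[2 * i + 1]
        out += [x0 * c - x1 * s, x0 * s + x1 * c]
    return out


def score(q, k, q_pos, k_pos, head_dim):
    # Queries and keys use rotary tables built from their own positions.
    q_cos, q_sin = rope_tables([q_pos], head_dim)
    k_cos, k_sin = rope_tables([k_pos], head_dim)
    q_rot = apply_rope(q, q_cos[0], q_sin[0])
    k_rot = apply_rope(k, k_cos[0], k_sin[0])
    return sum(a * b for a, b in zip(q_rot, k_rot))


q = [0.3, -0.2, 0.5, 0.1]
k = [0.4, 0.6, -0.1, 0.2]
# Shifting both timestamps by the same amount leaves the score unchanged.
s1 = score(q, k, q_pos=2.0, k_pos=0.5, head_dim=4)
s2 = score(q, k, q_pos=3.0, k_pos=1.5, head_dim=4)
print(abs(s1 - s2) < 1e-9)  # True
```

Because only the relative offset matters, video and audio positions can live on a shared time axis (for example, seconds) even though the two streams have different token rates.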
LTX2AudioVideoRotaryPosEmbed
Bases: Module
Video and audio rotary positional embeddings (RoPE) for the LTX-2.0 model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
causal_offset | `int`, *optional*, defaults to `1` | Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where the VAE treats the very first frame differently), but could also be 0 (for non-causal modeling). | 1 |
audio_latents_per_second instance-attribute
prepare_audio_coords
Create per-dimension [start, end) bounds (inclusive start, exclusive end) holding the start and end timestamps of each latent frame. The result has shape (batch_size, 1, num_patches, 2), where:
- axis 1 (size 1) represents the temporal dimension
- axis 3 (size 2) stores the [start, end) bounds within that dimension
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
batch_size | `int` | Batch size of the audio latents. | required |
num_frames | `int` | Number of latent frames in the audio latents. | required |
device | `torch.device` | Device on which to create the audio grid. | required |
shift | `int`, *optional*, defaults to `0` | Offset on the latent indices. Different shift values correspond to different overlapping windows with respect to the same underlying latent grid. | 0 |
Returns:
| Type | Description |
|---|---|
Tensor | Per-frame [start, end) timestamp bounds for the audio latents. |
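A pure-Python sketch of how such per-frame audio bounds could be computed, under stated assumptions: each latent frame covers `scale_factor` mel frames, `causal_offset` shifts the grid so that the first latent frame covers a single mel frame, and mel frames convert to seconds via `hop_length / sampling_rate`. The numbers (16 kHz, hop 160, scale factor 4) and the helper name `prepare_audio_coords_sketch` are hypothetical, not the module's actual configuration.

```python
def prepare_audio_coords_sketch(batch_size, num_frames, shift=0,
                                sampling_rate=16000, hop_length=160,
                                scale_factor=4, causal_offset=1, patch_size_t=1):
    """Return nested lists of shape (batch_size, 1, num_patches, 2), in seconds.

    All default values are hypothetical; the causal shift rule is an assumption.
    """
    seconds_per_mel_frame = hop_length / sampling_rate
    bounds = []
    # `shift` moves the window along the same underlying latent grid.
    for f in range(shift, num_frames + shift, patch_size_t):
        # Latent frame f maps to mel frames [f * s, (f + 1) * s), then the causal
        # shift (offset - s, clamped at 0) makes frame 0 cover a single mel frame.
        start_mel = max(f * scale_factor + causal_offset - scale_factor, 0)
        end_mel = max((f + patch_size_t) * scale_factor + causal_offset - scale_factor, 0)
        bounds.append([start_mel * seconds_per_mel_frame,
                       end_mel * seconds_per_mel_frame])
    # A single (temporal) axis per sample.
    return [[[list(b) for b in bounds]] for _ in range(batch_size)]


coords = prepare_audio_coords_sketch(batch_size=2, num_frames=3)
print(coords[0][0])  # roughly [[0.0, 0.01], [0.01, 0.05], [0.05, 0.09]]
```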
prepare_video_coords
prepare_video_coords(
batch_size: int,
num_frames: int,
height: int,
width: int,
device: device,
fps: float = 24.0,
) -> Tensor
Create per-dimension [start, end) bounds (inclusive start, exclusive end) for each patch with respect to the original pixel-space video grid (num_frames, height, width). The result has shape (batch_size, 3, num_patches, 2), where:
- axis 1 (size 3) enumerates the (frame, height, width) dimensions (e.g. index 0 corresponds to frames)
- axis 3 (size 2) stores the [start, end) indices within each dimension
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
batch_size | `int` | Batch size of the video latents. | required |
num_frames | `int` | Number of latent frames in the video latents. | required |
height | `int` | Latent height of the video latents. | required |
width | `int` | Latent width of the video latents. | required |
device | `torch.device` | Device on which to create the video grid. | required |
fps | `float`, *optional*, defaults to `24.0` | Frame rate of the video. | 24.0 |
Returns:
| Type | Description |
|---|---|
Tensor | Per-patch [start, end) bounds of shape (batch_size, 3, num_patches, 2). |
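A pure-Python sketch of the video grid, under stated assumptions: patch size 1, latent cell `[i, i + 1)` mapped to pixel range `[i * s, (i + 1) * s)` with hypothetical VAE scale factors `(8, 32, 32)`, the temporal axis shifted by `causal_offset - s` (clamped at 0) so the first latent frame covers one pixel frame, and temporal bounds divided by `fps` to get seconds. The flattening order (frame-major) and the helper name are also assumptions.

```python
def prepare_video_coords_sketch(batch_size, num_frames, height, width, fps=24.0,
                                scale_factors=(8, 32, 32), causal_offset=1):
    """Return nested lists of shape (batch_size, 3, num_patches, 2), patch size 1.

    scale_factors and the causal/fps handling are hypothetical assumptions.
    """
    t_scale, h_scale, w_scale = scale_factors
    frame_b, height_b, width_b = [], [], []
    for f in range(num_frames):
        for h in range(height):
            for w in range(width):
                # Causal temporal shift: latent frame 0 covers pixel frame [0, 1).
                start_t = max(f * t_scale + causal_offset - t_scale, 0)
                end_t = max((f + 1) * t_scale + causal_offset - t_scale, 0)
                # Temporal bounds converted from frame indices to seconds.
                frame_b.append([start_t / fps, end_t / fps])
                height_b.append([h * h_scale, (h + 1) * h_scale])
                width_b.append([w * w_scale, (w + 1) * w_scale])
    per_sample = [frame_b, height_b, width_b]
    return [[[list(b) for b in axis] for axis in per_sample] for _ in range(batch_size)]


coords = prepare_video_coords_sketch(batch_size=1, num_frames=2, height=2, width=3)
frames, heights, widths = coords[0]
print(frames[0], frames[6])  # latent frame 0 spans 1 pixel frame, frame 1 spans 8
```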
LTX2FeedForward
LTX2VideoTransformer3DModel
Bases: Module
A Transformer model for video-like data used in LTX.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
in_channels | `int`, defaults to `128` | The number of channels in the input. | 128 |
out_channels | `int`, defaults to `128` | The number of channels in the output. | 128 |
patch_size | `int`, defaults to `1` | The size of the spatial patches to use in the patch embedding layer. | 1 |
patch_size_t | `int`, defaults to `1` | The size of the temporal patches to use in the patch embedding layer. | 1 |
num_attention_heads | `int`, defaults to `32` | The number of heads to use for multi-head attention. | 32 |
attention_head_dim | `int`, defaults to `128` | The number of channels in each head. | 128 |
cross_attention_dim | `int`, defaults to `4096` | The number of channels in the encoder hidden states used for cross-attention. | 4096 |
num_layers | `int`, defaults to `48` | The number of Transformer blocks to use. | 48 |
activation_fn | `str`, defaults to `"gelu-approximate"` | Activation function to use in feed-forward. | 'gelu-approximate' |
qk_norm | `str`, defaults to `"rms_norm_across_heads"` | The normalization applied to queries and keys. | 'rms_norm_across_heads' |
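The defaults in the table fit together if the model width is `num_attention_heads * attention_head_dim`, the usual multi-head layout; `inner_dim` (used by `caption_projection` and `rope` below) is assumed to hold this product. The arithmetic also shows the per-rank width of the concatenated Q (or K) heads under a hypothetical tensor-parallel degree of 4, assuming heads are split evenly across ranks.

```python
num_attention_heads = 32
attention_head_dim = 128
cross_attention_dim = 4096

# Assumed: model width = heads x head size.
inner_dim = num_attention_heads * attention_head_dim
print(inner_dim)                         # 4096
print(inner_dim == cross_attention_dim)  # True

tp_size = 4  # hypothetical tensor-parallel degree
assert num_attention_heads % tp_size == 0, "heads must split evenly across ranks"
heads_per_rank = num_attention_heads // tp_size
local_qk_width = heads_per_rank * attention_head_dim
print(heads_per_rank, local_qk_width)    # 8 1024
```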
audio_caption_projection instance-attribute
audio_caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels,
hidden_size=audio_inner_dim,
)
audio_norm_out instance-attribute
audio_prompt_adaln instance-attribute
audio_prompt_adaln = LTX2AdaLayerNormSingle(
audio_inner_dim,
num_mod_params=2,
use_additional_conditions=False,
)
audio_rope instance-attribute
audio_rope = LTX2AudioVideoRotaryPosEmbed(
dim=audio_inner_dim,
patch_size=audio_patch_size,
patch_size_t=audio_patch_size_t,
base_num_frames=audio_pos_embed_max_pos,
sampling_rate=audio_sampling_rate,
hop_length=audio_hop_length,
scale_factors=[audio_scale_factor],
theta=rope_theta,
causal_offset=causal_offset,
modality="audio",
double_precision=rope_double_precision,
rope_type=rope_type,
num_attention_heads=audio_num_attention_heads,
)
audio_scale_shift_table instance-attribute
audio_time_embed instance-attribute
audio_time_embed = LTX2AdaLayerNormSingle(
audio_inner_dim,
num_mod_params=audio_num_mod_params,
use_additional_conditions=False,
)
av_cross_attn_audio_scale_shift instance-attribute
av_cross_attn_audio_scale_shift = LTX2AdaLayerNormSingle(
audio_inner_dim,
num_mod_params=4,
use_additional_conditions=False,
)
av_cross_attn_audio_v2a_gate instance-attribute
av_cross_attn_audio_v2a_gate = LTX2AdaLayerNormSingle(
audio_inner_dim,
num_mod_params=1,
use_additional_conditions=False,
)
av_cross_attn_video_a2v_gate instance-attribute
av_cross_attn_video_a2v_gate = LTX2AdaLayerNormSingle(
inner_dim,
num_mod_params=1,
use_additional_conditions=False,
)
av_cross_attn_video_scale_shift instance-attribute
av_cross_attn_video_scale_shift = LTX2AdaLayerNormSingle(
inner_dim,
num_mod_params=4,
use_additional_conditions=False,
)
caption_projection instance-attribute
caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels, hidden_size=inner_dim
)
config instance-attribute
config = SimpleNamespace(
in_channels=in_channels,
out_channels=out_channels,
patch_size=patch_size,
patch_size_t=patch_size_t,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=cross_attention_dim,
vae_scale_factors=vae_scale_factors,
pos_embed_max_pos=pos_embed_max_pos,
base_height=base_height,
base_width=base_width,
audio_in_channels=audio_in_channels,
audio_out_channels=audio_out_channels,
audio_patch_size=audio_patch_size,
audio_patch_size_t=audio_patch_size_t,
audio_num_attention_heads=audio_num_attention_heads,
audio_attention_head_dim=audio_attention_head_dim,
audio_cross_attention_dim=audio_cross_attention_dim,
audio_scale_factor=audio_scale_factor,
audio_pos_embed_max_pos=audio_pos_embed_max_pos,
audio_sampling_rate=audio_sampling_rate,
audio_hop_length=audio_hop_length,
num_layers=num_layers,
activation_fn=activation_fn,
qk_norm=qk_norm,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
caption_channels=caption_channels,
attention_bias=attention_bias,
attention_out_bias=attention_out_bias,
rope_theta=rope_theta,
rope_double_precision=rope_double_precision,
causal_offset=causal_offset,
timestep_scale_multiplier=timestep_scale_multiplier,
cross_attn_timestep_scale_multiplier=cross_attn_timestep_scale_multiplier,
rope_type=rope_type,
)
cross_attn_audio_rope instance-attribute
cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed(
dim=audio_cross_attention_dim,
patch_size=audio_patch_size,
patch_size_t=audio_patch_size_t,
base_num_frames=cross_attn_pos_embed_max_pos,
sampling_rate=audio_sampling_rate,
hop_length=audio_hop_length,
theta=rope_theta,
causal_offset=causal_offset,
modality="audio",
double_precision=rope_double_precision,
rope_type=rope_type,
num_attention_heads=audio_num_attention_heads,
)
cross_attn_rope instance-attribute
cross_attn_rope = LTX2AudioVideoRotaryPosEmbed(
dim=audio_cross_attention_dim,
patch_size=patch_size,
patch_size_t=patch_size_t,
base_num_frames=cross_attn_pos_embed_max_pos,
base_height=base_height,
base_width=base_width,
theta=rope_theta,
causal_offset=causal_offset,
modality="video",
double_precision=rope_double_precision,
rope_type=rope_type,
num_attention_heads=num_attention_heads,
)
norm_out instance-attribute
packed_modules_mapping class-attribute instance-attribute
prompt_adaln instance-attribute
prompt_adaln = LTX2AdaLayerNormSingle(
inner_dim,
num_mod_params=2,
use_additional_conditions=False,
)
rope instance-attribute
rope = LTX2AudioVideoRotaryPosEmbed(
dim=inner_dim,
patch_size=patch_size,
patch_size_t=patch_size_t,
base_num_frames=pos_embed_max_pos,
base_height=base_height,
base_width=base_width,
scale_factors=vae_scale_factors,
theta=rope_theta,
causal_offset=causal_offset,
modality="video",
double_precision=rope_double_precision,
rope_type=rope_type,
num_attention_heads=num_attention_heads,
)
scale_shift_table instance-attribute
time_embed instance-attribute
time_embed = LTX2AdaLayerNormSingle(
inner_dim,
num_mod_params=video_num_mod_params,
use_additional_conditions=False,
)
transformer_blocks instance-attribute
transformer_blocks = ModuleList(
[
(
LTX2VideoTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=cross_attention_dim,
audio_dim=audio_inner_dim,
audio_num_attention_heads=audio_num_attention_heads,
audio_attention_head_dim=audio_attention_head_dim,
audio_cross_attention_dim=audio_cross_attention_dim,
video_gated_attn=gated_attn,
video_cross_attn_adaln=cross_attn_mod,
audio_gated_attn=audio_gated_attn,
audio_cross_attn_adaln=audio_cross_attn_mod,
qk_norm=qk_norm,
activation_fn=activation_fn,
attention_bias=attention_bias,
attention_out_bias=attention_out_bias,
eps=norm_eps,
elementwise_affine=norm_elementwise_affine,
rope_type=rope_type,
perturbed_attn=perturbed_attn,
quant_config=quant_config,
prefix=f"transformer_blocks.{layer_idx}",
)
)
for layer_idx in (range(num_layers))
]
)
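The `prefix` strings threaded through these constructors compose into dotted module names. This sketch reproduces the pattern used by `transformer_blocks` (`f"transformer_blocks.{layer_idx}"`), `attn1` (`f"{prefix}.attn1" if prefix else "attn1"`), and `to_out` (`f"{prefix}.to_out.0"`) to show the fully qualified name of each block's attention output projection. Presumably these names are what `quant_config` and weight loading match against, but that use is an assumption; `block_prefix` and `submodule_prefix` are hypothetical helpers.

```python
def block_prefix(layer_idx: int) -> str:
    # Same pattern as prefix=f"transformer_blocks.{layer_idx}".
    return f"transformer_blocks.{layer_idx}"


def submodule_prefix(prefix: str, name: str) -> str:
    # Same pattern as f"{prefix}.name" if prefix else "name".
    return f"{prefix}.{name}" if prefix else name


num_layers = 48  # default from the parameter table
names = []
for layer_idx in range(num_layers):
    block = block_prefix(layer_idx)
    attn1 = submodule_prefix(block, "attn1")
    names.append(submodule_prefix(attn1, "to_out.0"))

print(names[0])   # transformer_blocks.0.attn1.to_out.0
print(names[-1])  # transformer_blocks.47.attn1.to_out.0
```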
forward
forward(
hidden_states: Tensor,
audio_hidden_states: Tensor,
encoder_hidden_states: Tensor,
audio_encoder_hidden_states: Tensor,
timestep: LongTensor,
audio_timestep: LongTensor | None = None,
sigma: Tensor | None = None,
audio_sigma: Tensor | None = None,
encoder_attention_mask: Tensor | None = None,
audio_encoder_attention_mask: Tensor | None = None,
num_frames: int | None = None,
height: int | None = None,
width: int | None = None,
fps: float = 24.0,
audio_num_frames: int | None = None,
video_coords: Tensor | None = None,
audio_coords: Tensor | None = None,
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
**kwargs,
) -> Tensor
Forward pass of the LTX-2.0 audiovisual video transformer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
hidden_states | `torch.Tensor` | Input patchified video latents of shape | required |
audio_hidden_states | `torch.Tensor` | Input patchified audio latents of shape | required |
encoder_hidden_states | `torch.Tensor` | Input video text embeddings of shape | required |
audio_encoder_hidden_states | `torch.Tensor` | Input audio text embeddings of shape | required |
timestep | `torch.Tensor` | Input timestep of shape | required |
audio_timestep | `torch.Tensor`, *optional* | Input timestep of shape | None |
encoder_attention_mask | `torch.Tensor`, *optional* | Optional multiplicative text attention mask of shape | None |
audio_encoder_attention_mask | `torch.Tensor`, *optional* | Optional multiplicative text attention mask of shape | None |
num_frames | `int`, *optional* | The number of latent video frames. Used if calculating the video coordinates for RoPE. | None |
height | `int`, *optional* | The latent video height. Used if calculating the video coordinates for RoPE. | None |
width | `int`, *optional* | The latent video width. Used if calculating the video coordinates for RoPE. | None |
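fps | `float`, *optional*, defaults to `24.0` | The video frame rate. Used if calculating the video coordinates for RoPE. | 24.0 |
audio_num_frames | `int`, *optional* | The number of latent audio frames. Used if calculating the audio coordinates for RoPE. | None |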
video_coords | `torch.Tensor`, *optional* | The video coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape | None |
audio_coords | `torch.Tensor`, *optional* | The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape | None |
attention_kwargs | `Dict[str, Any]`, *optional* | Optional dict of keyword args to be passed to the attention processor. | None |
return_dict | `bool`, *optional*, defaults to `True` | Whether to return a dict-like structured output of type | True |
Returns:
| Type | Description |
|---|---|
Tensor | |
LTX2VideoTransformerBlock
Bases: Module
Transformer block used in LTX-2.0.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
dim | `int` | The number of channels in the input and output. | required |
num_attention_heads | `int` | The number of heads to use for multi-head attention. | required |
attention_head_dim | `int` | The number of channels in each head. | required |
qk_norm | `str`, defaults to `"rms_norm_across_heads"` | The normalization applied to queries and keys. | 'rms_norm_across_heads' |
activation_fn | `str`, defaults to `"gelu-approximate"` | Activation function to use in feed-forward. | 'gelu-approximate' |
eps | `float`, defaults to `1e-6` | Epsilon value for normalization layers. | 1e-06 |
attn1 instance-attribute
attn1 = LTX2Attention(
query_dim=dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
bias=attention_bias,
cross_attention_dim=None,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=video_gated_attn,
quant_config=quant_config,
prefix=f"{prefix}.attn1" if prefix else "attn1",
)
attn2 instance-attribute
attn2 = LTX2Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=video_gated_attn,
quant_config=quant_config,
prefix=f"{prefix}.attn2" if prefix else "attn2",
disable_kv_quant=True,
)
audio_a2v_cross_attn_scale_shift_table instance-attribute
audio_attn1 instance-attribute
audio_attn1 = LTX2Attention(
query_dim=audio_dim,
heads=audio_num_attention_heads,
kv_heads=audio_num_attention_heads,
dim_head=audio_attention_head_dim,
bias=attention_bias,
cross_attention_dim=None,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=audio_gated_attn,
quant_config=quant_config,
prefix=f"{prefix}.audio_attn1"
if prefix
else "audio_attn1",
)
audio_attn2 instance-attribute
audio_attn2 = LTX2Attention(
query_dim=audio_dim,
cross_attention_dim=audio_cross_attention_dim,
heads=audio_num_attention_heads,
kv_heads=audio_num_attention_heads,
dim_head=audio_attention_head_dim,
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=audio_gated_attn,
quant_config=quant_config,
prefix=f"{prefix}.audio_attn2"
if prefix
else "audio_attn2",
disable_kv_quant=True,
)
audio_ff instance-attribute
audio_ff = LTX2FeedForward(
audio_dim,
activation_fn=activation_fn,
quant_config=quant_config,
prefix=f"{prefix}.audio_ff" if prefix else "audio_ff",
)
audio_norm1 instance-attribute
audio_norm2 instance-attribute
audio_norm3 instance-attribute
audio_prompt_scale_shift_table instance-attribute
audio_scale_shift_table instance-attribute
audio_to_video_attn instance-attribute
audio_to_video_attn = LTX2Attention(
query_dim=dim,
cross_attention_dim=audio_dim,
heads=audio_num_attention_heads,
kv_heads=audio_num_attention_heads,
dim_head=audio_attention_head_dim,
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=video_gated_attn,
quant_config=quant_config,
prefix=f"{prefix}.audio_to_video_attn"
if prefix
else "audio_to_video_attn",
)
audio_to_video_norm instance-attribute
cross_attn_adaln instance-attribute
ff instance-attribute
ff = LTX2FeedForward(
dim,
activation_fn=activation_fn,
quant_config=quant_config,
prefix=f"{prefix}.ff" if prefix else "ff",
)
norm1 instance-attribute
norm2 instance-attribute
norm3 instance-attribute
scale_shift_table instance-attribute
video_a2v_cross_attn_scale_shift_table instance-attribute
video_to_audio_attn instance-attribute
video_to_audio_attn = LTX2Attention(
query_dim=audio_dim,
cross_attention_dim=dim,
heads=audio_num_attention_heads,
kv_heads=audio_num_attention_heads,
dim_head=audio_attention_head_dim,
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=audio_gated_attn,
quant_config=quant_config,
prefix=f"{prefix}.video_to_audio_attn"
if prefix
else "video_to_audio_attn",
)
video_to_audio_norm instance-attribute
forward
forward(
hidden_states: Tensor,
audio_hidden_states: Tensor,
encoder_hidden_states: Tensor,
audio_encoder_hidden_states: Tensor,
temb: Tensor,
temb_audio: Tensor,
temb_ca_scale_shift: Tensor,
temb_ca_audio_scale_shift: Tensor,
temb_ca_gate: Tensor,
temb_ca_audio_gate: Tensor,
temb_prompt: Tensor | None = None,
temb_prompt_audio: Tensor | None = None,
video_rotary_emb: tuple[Tensor, Tensor] | None = None,
audio_rotary_emb: tuple[Tensor, Tensor] | None = None,
ca_video_rotary_emb: tuple[Tensor, Tensor] | None = None,
ca_audio_rotary_emb: tuple[Tensor, Tensor] | None = None,
encoder_attention_mask: Tensor | None = None,
audio_encoder_attention_mask: Tensor | None = None,
self_attention_mask: Tensor | None = None,
audio_self_attention_mask: Tensor | None = None,
a2v_cross_attention_mask: Tensor | None = None,
v2a_cross_attention_mask: Tensor | None = None,
use_a2v_cross_attention: bool = True,
use_v2a_cross_attention: bool = True,
perturbation_mask: Tensor | None = None,
all_perturbed: bool | None = None,
) -> Tensor
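The block's `forward` receives timestep modulations (`temb`, `temb_audio`, and the cross-attention variants), and the block owns `scale_shift_table` parameters. Below is a minimal pure-Python sketch of a PixArt-Alpha-style modulated residual update. It assumes the global adaLN-single output is added to a per-block learned table and split into shift, scale, and gate, with a parameter-free RMS norm before the sublayer. The chunk order and norm choice are assumptions rather than the LTX-2 code, and `sublayer` stands in for attention or the feed-forward network.

```python
import math


def rms_norm(x, eps=1e-6):
    # Parameter-free RMS normalization over the feature axis (assumed norm type).
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def modulated_residual(x, sublayer, shift, scale, gate):
    # x <- x + gate * sublayer(norm(x) * (1 + scale) + shift)
    h = [n * (1.0 + sc) + sh for n, sc, sh in zip(rms_norm(x), scale, shift)]
    out = sublayer(h)
    return [xi + g * o for xi, g, o in zip(x, gate, out)]


def add(a, b):
    return [u + v for u, v in zip(a, b)]


def identity(h):
    # Stand-in for an attention or feed-forward sublayer.
    return h


dim = 4
x = [0.5, -1.0, 2.0, 0.0]
# Global modulation from adaLN-single plus a hypothetical per-block learned table.
global_mod = [[0.1] * dim, [0.2] * dim, [1.0] * dim]  # shift, scale, gate
block_table = [[0.0] * dim, [-0.2] * dim, [-1.0] * dim]
shift, scale, gate = (add(g, t) for g, t in zip(global_mod, block_table))

y = modulated_residual(x, identity, shift, scale, gate)
print(y == x)  # True: the combined gate is zero, so the sublayer output is suppressed
```

A zero gate turns the sublayer into a no-op on the residual stream, which is one reason gated modulation is a convenient way to phase a sublayer's contribution in and out per timestep.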
TensorParallelRMSNorm
Bases: Module
RMSNorm for query/key normalization that computes its statistics across tensor-parallel (TP) shards.
LTX2 uses qk_norm="rms_norm_across_heads" while Q/K are sharded across TP ranks. A local RMSNorm would compute statistics over the local shard only, which changes the normalization when TP > 1. To match the global RMS across all heads, the squared sum is all-reduced across ranks before normalizing.
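The all-reduce argument can be checked numerically. This pure-Python sketch simulates two TP ranks holding one head each: a plain local RMSNorm disagrees with the unsharded result, while summing the per-rank squared sums (a plain `sum` standing in for the all-reduce) reproduces it. The learnable elementwise weight is omitted, and the helper names are hypothetical.

```python
import math


def local_rmsnorm(shard, eps=1e-6):
    # Statistics from the local shard only (what a plain RMSNorm would do under TP).
    rms = math.sqrt(sum(v * v for v in shard) / len(shard) + eps)
    return [v / rms for v in shard]


def tp_rmsnorm(shards, eps=1e-6):
    # Each rank computes the sum of squares of its own shard ...
    local_sums = [sum(v * v for v in shard) for shard in shards]
    # ... an all-reduce (simulated here by a plain sum) yields the global total ...
    global_sum = sum(local_sums)
    total_dim = sum(len(shard) for shard in shards)
    rms = math.sqrt(global_sum / total_dim + eps)
    # ... so every rank normalizes with the same global RMS.
    return [[v / rms for v in shard] for shard in shards]


# Concatenated q (or k) features across all heads: 2 heads x head_dim 2.
full = [3.0, 4.0, 0.5, -0.5]
shards = [full[:2], full[2:]]  # tp_size = 2, one head per rank

reference = local_rmsnorm(full)  # unsharded result
tp_result = [v for s in tp_rmsnorm(shards) for v in s]
naive = [v for s in shards for v in local_rmsnorm(s)]

print(all(abs(a - b) < 1e-12 for a, b in zip(reference, tp_result)))  # True
print(all(abs(a - b) < 1e-12 for a, b in zip(reference, naive)))      # False
```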