vllm_omni.diffusion.models.cosmos3.transformer_cosmos3 ¶
Cosmos3 VFM Transformer for vllm-omni.
Implements the Mixture-of-Transformers architecture with two pathways:
- Understanding (UND): causal self-attention on text tokens (Qwen3-VL backbone)
- Generation (GEN): cross-attention where visual Q attends to [K_und, K_gen]
Ported from the TRT-LLM integration (tekit branch user/shreyasm/cosmos3).
Cosmos3CausalAttention ¶
Bases: Module
Understanding pathway: causal self-attention on text tokens.
Returns (output, K, V) where K/V are post-norm, post-RoPE for the generation pathway's cross-attention.
attn instance-attribute ¶
attn = Attention(
num_heads=num_heads,
head_size=head_dim,
causal=True,
softmax_scale=1.0 / head_dim**0.5,
num_kv_heads=num_kv_heads,
skip_sequence_parallel=True,
)
to_k instance-attribute ¶
to_k = ColumnParallelLinear(
hidden_size,
num_kv_heads * head_dim,
bias=False,
gather_output=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.to_k",
)
to_out instance-attribute ¶
to_out = RowParallelLinear(
num_heads * head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.to_out",
)
to_q instance-attribute ¶
to_q = ColumnParallelLinear(
hidden_size,
num_heads * head_dim,
bias=False,
gather_output=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.to_q",
)
to_v instance-attribute ¶
to_v = ColumnParallelLinear(
hidden_size,
num_kv_heads * head_dim,
bias=False,
gather_output=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.to_v",
)
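The projections above are column-parallel with gather_output=False, so each tensor-parallel rank keeps only its slice of the heads, and to_out (row-parallel, input_is_parallel=True) consumes that slice directly. A minimal sketch of the per-rank widths follows. It assumes both head counts divide evenly across ranks; the function name, head_dim=128 and tp_size=4 are illustrative and not taken from the module.

```python
def local_projection_widths(num_heads, num_kv_heads, head_dim, tp_size):
    """Per-rank output widths of to_q / to_k / to_v under tensor parallelism."""
    # Assumption: no KV-head replication, so both counts must divide evenly.
    if num_heads % tp_size or num_kv_heads % tp_size:
        raise ValueError("head counts must be divisible by tp_size in this sketch")
    q_width = (num_heads // tp_size) * head_dim
    kv_width = (num_kv_heads // tp_size) * head_dim
    # to_out takes the local Q-sized slice as input and reduces across ranks.
    return {"to_q": q_width, "to_k": kv_width, "to_v": kv_width, "to_out_in": q_width}


widths = local_projection_widths(num_heads=32, num_kv_heads=8, head_dim=128, tp_size=4)
print(widths)
```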
Cosmos3CrossAttention ¶
Bases: Module
Generation pathway: full attention where visual Q attends to all K/V.
- Non-SP path: explicit cat([k_und, k_gen]). Text conditioning is always present because K/V are physically concatenated.
- SP path (Ulysses active): k_und/v_und are passed as joint_key/joint_value in AttentionMetadata. The Ulysses wrapper head-slices the replicated UND K/V and performs all-to-all on the sharded GEN Q/K/V so every query sees the full context.
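The non-SP path can be illustrated with a single-head, single-query toy in plain Python: the visual query attends over the concatenation of text and visual keys, with no causal mask. This is only a sketch of the idea; the helper name and the toy values are hypothetical, and the real layer operates on batched, multi-head tensors.

```python
import math


def attend(q, keys, values, scale):
    """Softmax attention for one query over lists of key/value vectors."""
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)                                  # subtract max for stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return out, weights


head_dim = 2
scale = 1.0 / head_dim ** 0.5                        # same form as softmax_scale
k_und, v_und = [[1.0, 0.0]], [[10.0, 0.0]]           # one text token
k_gen = [[0.0, 1.0], [1.0, 1.0]]                     # two visual tokens
v_gen = [[0.0, 1.0], [0.0, 2.0]]
q = [1.0, 0.0]                                       # one visual query

# Equivalent of cat([k_und, k_gen]) / cat([v_und, v_gen]) along the sequence axis.
out, weights = attend(q, k_und + k_gen, v_und + v_gen, scale)
```

Every query produces one weight per text token and one per visual token, which is why text conditioning cannot be dropped on this path.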
attn instance-attribute ¶
attn = Attention(
num_heads=num_heads,
head_size=head_dim,
causal=False,
softmax_scale=1.0 / head_dim**0.5,
num_kv_heads=num_kv_heads,
)
to_k instance-attribute ¶
to_k = ColumnParallelLinear(
hidden_size,
num_kv_heads * head_dim,
bias=False,
gather_output=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.to_k",
)
to_out instance-attribute ¶
to_out = RowParallelLinear(
num_heads * head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.to_out",
)
to_q instance-attribute ¶
to_q = ColumnParallelLinear(
hidden_size,
num_heads * head_dim,
bias=False,
gather_output=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.to_q",
)
to_v instance-attribute ¶
to_v = ColumnParallelLinear(
hidden_size,
num_kv_heads * head_dim,
bias=False,
gather_output=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.to_v",
)
forward ¶
forward(
hidden_states: Tensor,
k_und: Tensor,
v_und: Tensor,
freqs_cos: Tensor,
freqs_sin: Tensor,
) -> Tensor
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_states | Tensor | [B, S_gen_local, hidden_size] (may be sequence-sharded) | required |
| k_und | Tensor | [B, S_und, H_kv_local, D] pre-computed UND keys (TP-sharded, post-norm/RoPE) | required |
| v_und | Tensor | [B, S_und, H_kv_local, D] pre-computed UND values (TP-sharded) | required |
| freqs_cos | Tensor | [B, S_gen_local, 1, D] | required |
| freqs_sin | Tensor | [B, S_gen_local, 1, D] | required |
Cosmos3GatedMLP ¶
Bases: Module
down_proj instance-attribute ¶
down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
gate_proj instance-attribute ¶
gate_proj = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=False,
gather_output=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_proj",
)
up_proj instance-attribute ¶
up_proj = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=False,
gather_output=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.up_proj",
)
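The gate_proj / up_proj / down_proj layout matches the usual gated feed-forward form, down(act(gate(x)) * up(x)). The sketch below assumes a SiLU activation (as in Qwen3-style SwiGLU blocks); the source does not state which activation is used, and the helper names are hypothetical.

```python
import math


def silu(x):
    # Assumed activation; the actual module may use a different one.
    return x / (1.0 + math.exp(-x))


def matvec(w, x):
    """Multiply an [out][in] matrix by an [in] vector."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def gated_mlp(x, w_gate, w_up, w_down):
    gate = [silu(g) for g in matvec(w_gate, x)]      # hidden -> intermediate
    up = matvec(w_up, x)                             # hidden -> intermediate
    hidden = [g * u for g, u in zip(gate, up)]       # elementwise gating
    return matvec(w_down, hidden)                    # intermediate -> hidden


x = [1.0, 2.0]                                        # hidden_size = 2
w_up = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]           # intermediate_size = 3
w_down = [[1.0, 1.0, 1.0], [0.5, 0.5, 0.5]]
w_gate_zero = [[0.0, 0.0]] * 3
closed = gated_mlp(x, w_gate_zero, w_up, w_down)      # silu(0) = 0 closes the gate
```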
Cosmos3GenDecoderLayer ¶
Bases: Module
Generation pathway decoder layer: cross-attention (to UND K/V) + MLP.
cross_attention instance-attribute ¶
cross_attention = Cosmos3CrossAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
rms_norm_eps=rms_norm_eps,
quant_config=quant_config,
prefix=f"{prefix}.cross_attention",
)
mlp instance-attribute ¶
mlp = Cosmos3GatedMLP(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
Cosmos3GenSPPrepare ¶
Cosmos3LanguageModel ¶
Bases: Module
Understanding pathway: a standard causal LM that processes text tokens.
Returns per-layer K/V tensors for the generation pathway's cross-attention. The UND pathway is independent of the denoising step, so its K/V can be computed once and reused across all sampling steps.
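Because the UND K/V do not depend on the denoising step, a sampling loop can compute them once and pass the same list to every step. The following is a schematic sketch only: CountingUnd, sample and toy_gen_step are hypothetical stand-ins, and strings and integers take the place of tensors.

```python
class CountingUnd:
    """Stand-in for the UND pathway that counts how often it is evaluated."""

    def __init__(self, num_layers):
        self.num_layers = num_layers
        self.calls = 0

    def __call__(self, text_ids):
        self.calls += 1
        # One (K, V) pair per layer; strings stand in for tensors.
        return [(f"K{i}", f"V{i}") for i in range(self.num_layers)]


def sample(und, gen_step, text_ids, latents, timesteps):
    kv_per_layer = und(text_ids)                  # computed once, before the loop
    for t in timesteps:
        latents = gen_step(latents, t, kv_per_layer)   # reused at every step
    return latents


def toy_gen_step(latents, t, kv_per_layer):
    # Placeholder update so the loop has an observable effect.
    return latents + len(kv_per_layer)


und = CountingUnd(num_layers=3)
result = sample(und, toy_gen_step, text_ids=[1, 2, 3], latents=0,
                timesteps=[1.0, 0.5, 0.0])
```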
layers instance-attribute ¶
layers = ModuleList(
[
(
Cosmos3UndDecoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
rms_norm_eps=rms_norm_eps,
quant_config=quant_config,
prefix=f"{prefix}.layers.{i}",
)
)
for i in (range(num_hidden_layers))
]
)
rotary_emb instance-attribute ¶
rotary_emb = Qwen3VLTextRotaryEmbedding(
head_dim=head_dim,
rope_theta=rope_theta,
mrope_section=mrope_section,
)
forward ¶
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| text_ids | Tensor | [B, S] token IDs | required |
| freqs | tuple[Tensor, Tensor] | (cos, sin) each [B, S, 1, D] | required |
Returns:
| Type | Description |
|---|---|
| list[tuple[Tensor, Tensor]] | List of (K, V) per layer, each [B, S, H_kv, D]. |
No padding mask is applied: with right-padding + causal self-attention, real query positions only attend to real keys, and the caller trims pad K/V via max_real_len before the GEN cross-attention sees them.
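The argument can be checked with index arithmetic alone: under a causal mask, query i sees keys 0..i, so with right-padding every key visible to a real query is itself real. The sketch below uses hypothetical helper names and plain lists in place of tensors; only max_real_len comes from the description above.

```python
def attended_keys(seq_len):
    """Causal self-attention: query i may attend to keys 0..i."""
    return {i: list(range(i + 1)) for i in range(seq_len)}


def real_queries_see_only_real_keys(seq_len, real_len):
    # Right-padding: positions [0, real_len) are real, [real_len, seq_len) are pad.
    visible = attended_keys(seq_len)
    return all(j < real_len for i in range(real_len) for j in visible[i])


def trim_pad(kv_positions, max_real_len):
    """Drop pad positions from a per-position K or V list before cross-attention."""
    return kv_positions[:max_real_len]


ok = real_queries_see_only_real_keys(seq_len=8, real_len=5)
trimmed = trim_pad(list(range(8)), max_real_len=5)
```

Pad queries still produce K/V at positions real_len and beyond, which is why the trim step is needed before the GEN pathway consumes them.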
Cosmos3UndDecoderLayer ¶
Bases: Module
Understanding pathway decoder layer: causal self-attention + MLP.
mlp instance-attribute ¶
mlp = Cosmos3GatedMLP(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
post_attention_layernorm instance-attribute ¶
post_attention_layernorm = RMSNorm(
hidden_size, eps=rms_norm_eps
)
self_attn instance-attribute ¶
self_attn = Cosmos3CausalAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
rms_norm_eps=rms_norm_eps,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
Cosmos3VFMTransformer ¶
Bases: Module
Cosmos3 VFM Transformer: UND language model + GEN denoising layers.
The UND pathway runs once per generation (K/V cached). The GEN pathway runs at each denoising step.
Layerwise offloading uses gen_layers as the block container.
Sequence parallelism uses _sp_plan to shard/gather the GEN pathway at module boundaries. Cosmos3CrossAttention checks forward_context.sp_active at runtime and routes to the framework Attention layer (with Ulysses all-to-all) or plain SDPA accordingly.
action_dim instance-attribute ¶
action_dim = int(
action_dim_value if action_dim_value is not None else 64
)
action_gen instance-attribute ¶
action_modality_embed instance-attribute ¶
action_proj_in instance-attribute ¶
action_proj_in = DomainAwareLinear(
action_dim,
hidden_size,
num_embodiment_domains,
dtype=dtype,
)
action_proj_out instance-attribute ¶
action_proj_out = DomainAwareLinear(
hidden_size,
action_dim,
num_embodiment_domains,
dtype=dtype,
)
enable_fps_modulation instance-attribute ¶
enable_fps_modulation = bool(
_tf_config_get(
model_config, "enable_fps_modulation", True
)
)
gen_layers instance-attribute ¶
gen_layers = ModuleList(
[
(
Cosmos3GenDecoderLayer(
layer_idx=i,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
rms_norm_eps=rms_norm_eps,
quant_config=quant_config,
prefix=f"gen_layers.{i}",
)
)
for i in (range(num_hidden_layers))
]
)
hidden_size instance-attribute ¶
hidden_size = int(
_tf_config_get(model_config, "hidden_size", 4096)
)
intermediate_size instance-attribute ¶
intermediate_size = int(
_tf_config_get(model_config, "intermediate_size", 12288)
)
language_model instance-attribute ¶
language_model = Cosmos3LanguageModel(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
vocab_size=vocab_size,
rms_norm_eps=rms_norm_eps,
rope_theta=rope_theta,
mrope_section=mrope_section,
quant_config=quant_config,
prefix="language_model",
)
latent_channel_size instance-attribute ¶
latent_channel_size = int(
_tf_config_get(model_config, "latent_channel", 48)
)
latent_patch_size instance-attribute ¶
latent_patch_size = int(
_tf_config_get(model_config, "latent_patch_size", 2)
)
num_attention_heads instance-attribute ¶
num_attention_heads = int(
_tf_config_get(model_config, "num_attention_heads", 32)
)
num_embodiment_domains instance-attribute ¶
num_embodiment_domains = int(
_od_config_get(od_config, "num_embodiment_domains", 32)
)
num_hidden_layers instance-attribute ¶
num_hidden_layers = int(
_tf_config_get(model_config, "num_hidden_layers", 36)
)
num_key_value_heads instance-attribute ¶
num_key_value_heads = int(
_tf_config_get(model_config, "num_key_value_heads", 8)
)
patch_latent_dim instance-attribute ¶
rms_norm_eps instance-attribute ¶
rms_norm_eps = float(
_tf_config_get(model_config, "rms_norm_eps", 1e-06)
)
rope_theta instance-attribute ¶
rope_theta = float(
_tf_config_get(model_config, "rope_theta", 5000000)
)
temporal_compression_factor instance-attribute ¶
temporal_compression_factor = int(
temporal_compression_factor
)
temporal_compression_factor_sound instance-attribute ¶
temporal_compression_factor_sound = int(
_tf_config_get(
model_config, "temporal_compression_factor_sound", 1
)
)
temporal_modality_margin instance-attribute ¶
temporal_modality_margin = int(
_tf_config_get(
model_config,
"unified_3d_mrope_temporal_modality_margin",
15000,
)
)
time_embedder instance-attribute ¶
time_embedder = TimestepEmbedder(
hidden_size, target_dtype=dtype
)
timestep_scale instance-attribute ¶
timestep_scale = float(
_tf_config_get(model_config, "timestep_scale", 0.001)
)
vocab_size instance-attribute ¶
vocab_size = int(
_tf_config_get(model_config, "vocab_size", 151936)
)
forward ¶
forward(
hidden_states: Tensor,
timestep: Tensor,
text_ids: Tensor,
text_mask: Tensor,
video_shape: tuple[int, int, int],
fps: float | None = None,
action_latents: Tensor | None = None,
action_domain_ids: Tensor | None = None,
action_noisy_mask: Tensor | None = None,
action_start_frame_offset: int = 1,
action_fps: float | None = None,
sound_latents: Tensor | None = None,
noisy_frame_mask: Tensor | None = None,
**kwargs,
) -> Tensor | tuple[Tensor, ...]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_states | Tensor | [B, C, t, h, w] noisy latents | required |
| timestep | Tensor | [B] diffusion timestep | required |
| text_ids | Tensor | [B, S_text] tokenized text | required |
| text_mask | Tensor | [B, S_text] attention mask (1=real, 0=pad) | required |
| video_shape | tuple[int, int, int] | (t, h, w) in latent space | required |
| fps | float \| None | Video frame rate for temporal mRoPE modulation | None |
| action_latents | Tensor \| None | Optional [B, T_action, D_action] noisy action latents. | None |
| action_domain_ids | Tensor \| None | Optional [B] embodiment domain IDs for action projections. | None |
| action_noisy_mask | Tensor \| None | Optional [B, T_action, 1] mask where 1=noisy action token and 0=clean conditioned token. | None |
| sound_latents | Tensor \| None | Optional [B, C_sound, T_sound] noisy sound latents. | None |
| noisy_frame_mask | Tensor \| None | Optional [B, 1, t, 1, 1] mask where 1=noisy (add timestep embedding, predict velocity) and 0=conditioned (clean context, skip timestep embedding). None means all frames noisy (T2V mode). | None |
Returns:
| Type | Description |
|---|---|
| Tensor \| tuple[Tensor, ...] | [B, C, t, h, w] velocity prediction, or a tuple of outputs in video, action, sound order when extra modalities are provided. |
pack_action ¶
Validate and return action latents as [B, T_action, D_action] tokens.
pack_sound ¶
[B, C_sound, T_sound] -> [B, T_sound, C_sound].
patchify ¶
[B, C, t, h, w] -> [B, t*hp*wp, p*p*C], padding h/w if needed.
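A shape-only sketch of this step, assuming the output is [B, t*hp*wp, p*p*C] with hp = ceil(h / p) and wp = ceil(w / p), and that h and w are padded up to the next multiple of the patch size p. The function name is hypothetical and the ordering of elements inside each token is not modelled.

```python
import math


def patchify_shape(batch, channels, t, h, w, patch_size):
    """Return the patchified shape and the (pad_h, pad_w) needed to reach it."""
    hp = math.ceil(h / patch_size)            # patches along height
    wp = math.ceil(w / patch_size)            # patches along width
    pad_h = hp * patch_size - h
    pad_w = wp * patch_size - w
    tokens = t * hp * wp
    token_dim = patch_size * patch_size * channels
    return (batch, tokens, token_dim), (pad_h, pad_w)


# 48 channels and patch size 2 match the latent_channel / latent_patch_size defaults.
shape, padding = patchify_shape(batch=1, channels=48, t=4, h=45, w=80, patch_size=2)
print(shape, padding)
```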
unpack_action staticmethod ¶
Return [B, T_action, D_action] action predictions.
unpack_sound staticmethod ¶
[B, T_sound, C_sound] -> [B, C_sound, T_sound].
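pack_sound and unpack_sound are inverse transposes of the last two axes. A minimal sketch on nested lists (the real methods operate on tensors; the function bodies here are illustrative only):

```python
def pack_sound(x):
    """[B, C_sound, T_sound] -> [B, T_sound, C_sound] on nested lists."""
    return [[list(col) for col in zip(*sample)] for sample in x]


def unpack_sound(x):
    """[B, T_sound, C_sound] -> [B, C_sound, T_sound] on nested lists."""
    return [[list(col) for col in zip(*sample)] for sample in x]


latents = [[[1, 2, 3], [4, 5, 6]]]            # B=1, C_sound=2, T_sound=3
tokens = pack_sound(latents)                  # one token per time step
restored = unpack_sound(tokens)
```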
DomainAwareLinear ¶
Bases: Module
Linear projection with one weight/bias pair per action embodiment domain.
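A minimal sketch of the idea, assuming the domain ID simply indexes into a bank of per-domain weights and biases. The class name, storage layout and per-sample call signature are hypothetical; the actual module works on batched tensors.

```python
class DomainAwareLinearSketch:
    """One [out][in] weight matrix and one [out] bias vector per domain."""

    def __init__(self, weights, biases):
        self.weights = weights                # weights[d][o][i]
        self.biases = biases                  # biases[d][o]

    def __call__(self, x, domain_id):
        w = self.weights[domain_id]           # select this sample's domain
        b = self.biases[domain_id]
        return [sum(wi * xi for wi, xi in zip(row, x)) + bi
                for row, bi in zip(w, b)]


proj = DomainAwareLinearSketch(
    weights=[[[1.0, 0.0]], [[0.0, 1.0]]],     # 2 domains, in=2, out=1
    biases=[[0.0], [10.0]],
)
y0 = proj([2.0, 3.0], domain_id=0)
y1 = proj([2.0, 3.0], domain_id=1)
```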
Qwen3VLTextRotaryEmbedding ¶
Bases: Module
Multi-dimensional rotary position embedding for Qwen3-VL.
apply_interleaved_mrope ¶
Reorganize from chunked [TTT...HHH...WWW] to interleaved [THTHW...].
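The reordering can be sketched on axis labels rather than frequencies. The version below assumes the slice-assignment scheme used in the Hugging Face Qwen3-VL reference code (start from T everywhere, then overwrite every third slot with H and W); the exact ordering in this module may differ, and the function name and the small mrope_section are illustrative.

```python
def interleave_layout(mrope_section):
    """Return the per-frequency axis labels after interleaving."""
    t_len, h_len, w_len = mrope_section
    total = t_len + h_len + w_len
    layout = ["T"] * total                    # start from the temporal axis
    for label, offset, length in (("H", 1, h_len), ("W", 2, w_len)):
        # Overwrite slots offset, offset+3, ... below length*3 with this axis.
        for idx in range(offset, length * 3, 3):
            layout[idx] = label
    return "".join(layout)


layout = interleave_layout((4, 2, 2))
print(layout)
```

Under this assumption the leftover temporal slots end up at the tail of the layout once H and W are exhausted.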
TimestepEmbedder ¶
Bases: Module
Embeds scalar timesteps into vector representations.
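The source does not describe how the embedding is built. A common choice in diffusion transformers is a sinusoidal encoding, typically followed by a small MLP; the sketch below shows only the sinusoidal part and is an assumption, not the module's confirmed behavior. The function name and max_period are hypothetical.

```python
import math


def sinusoidal_embedding(timestep, dim, max_period=10000.0):
    """Map a scalar timestep to a dim-sized vector of cosines and sines."""
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to roughly 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [timestep * f for f in freqs]
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]


emb_zero = sinusoidal_embedding(0.0, dim=8)
emb = sinusoidal_embedding(0.5, dim=8)
```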
compute_mrope_position_ids_action ¶
compute_mrope_position_ids_action(
grid_t: int,
temporal_offset: int | float,
action_fps: float | None,
base_fps: float = 24.0,
base_temporal_compression_factor: int = 4,
enable_fps_modulation: bool = True,
start_frame_offset: int = 1,
) -> tuple[Tensor, int | float]
Generate mRoPE IDs for action tokens as a frame-rate (T, 1, 1) grid.
compute_mrope_position_ids_sound ¶
compute_mrope_position_ids_sound(
grid_t: int,
temporal_offset: int | float,
sound_latent_fps: float,
base_fps: float = 24.0,
temporal_compression_factor_sound: int = 1,
enable_fps_modulation: bool = True,
) -> tuple[Tensor, int | float]
Generate mRoPE IDs for sound tokens as a (T, 1, 1) grid.
compute_mrope_position_ids_text ¶
Generate 3D mRoPE position IDs for text tokens.
Text tokens: all three axes (t, h, w) share the same monotonically increasing position IDs.
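A minimal sketch of that rule on plain lists. The signature is hypothetical (the real function's parameters are not shown here); returning the next temporal offset alongside the IDs is an assumption modelled on the sibling functions' tuple return type.

```python
def text_position_ids(seq_len, temporal_offset=0):
    """Return ([t_ids, h_ids, w_ids], next_offset) for a run of text tokens."""
    ids = [temporal_offset + i for i in range(seq_len)]
    # All three axes carry the same monotonically increasing IDs.
    return [list(ids), list(ids), list(ids)], temporal_offset + seq_len


pos, next_offset = text_position_ids(seq_len=4, temporal_offset=2)
```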
compute_mrope_position_ids_vision ¶
compute_mrope_position_ids_vision(
grid_t: int,
grid_h: int,
grid_w: int,
temporal_offset: int | float,
fps: float | None = None,
base_fps: float = 24.0,
temporal_compression_factor: int = 4,
base_temporal_compression_factor: int | None = None,
enable_fps_modulation: bool = True,
start_frame_offset: int = 0,
) -> tuple[Tensor, int | float]
Generate 3D mRoPE position IDs for vision tokens.
Creates a (t, h, w) position grid with spatial indices reset per segment (Qwen3VL-style). Flattened in t-major order.
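A sketch of the grid construction with FPS modulation and compression factors left out. It assumes "reset per segment" means the h and w indices restart at 0 for each vision segment while only the temporal axis carries the offset; the function name and simplified signature are hypothetical.

```python
from itertools import product


def vision_position_ids(grid_t, grid_h, grid_w, temporal_offset=0):
    """Return [t_ids, h_ids, w_ids], flattened in t-major order."""
    t_ids, h_ids, w_ids = [], [], []
    # product() varies the last axis fastest, so t is the slowest (major) axis.
    for t, h, w in product(range(grid_t), range(grid_h), range(grid_w)):
        t_ids.append(temporal_offset + t)
        h_ids.append(h)                       # assumed to restart at 0
        w_ids.append(w)                       # assumed to restart at 0
    return [t_ids, h_ids, w_ids]


grid = vision_position_ids(grid_t=2, grid_h=2, grid_w=2, temporal_offset=10)
```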