
vllm_omni.diffusion.models.cosmos3.transformer_cosmos3

Cosmos3 VFM Transformer for vllm-omni.

Implements the Mixture-of-Transformers architecture with two pathways:

  • Understanding (UND): causal self-attention on text tokens (Qwen3-VL backbone).

  • Generation (GEN): cross-attention where visual Q attends to [K_und, K_gen].

Ported from the TRT-LLM integration (tekit branch user/shreyasm/cosmos3).

logger module-attribute

logger = init_logger(__name__)

Cosmos3CausalAttention

Bases: Module

Understanding pathway: causal self-attention on text tokens.

Returns (output, K, V). K is taken after norm_k and RoPE, and V is the plain to_v projection; both are reused by the generation pathway's cross-attention.

attn instance-attribute

attn = Attention(
    num_heads=num_heads,
    head_size=head_dim,
    causal=True,
    softmax_scale=1.0 / head_dim**0.5,
    num_kv_heads=num_kv_heads,
    skip_sequence_parallel=True,
)

head_dim instance-attribute

head_dim = head_dim

hidden_size instance-attribute

hidden_size = hidden_size

norm_k instance-attribute

norm_k = RMSNorm(head_dim, eps=rms_norm_eps)

norm_q instance-attribute

norm_q = RMSNorm(head_dim, eps=rms_norm_eps)

num_heads instance-attribute

num_heads = num_attention_heads

num_heads_local instance-attribute

num_heads_local = num_heads // tp_size

num_kv_heads instance-attribute

num_kv_heads = num_key_value_heads

num_kv_heads_local instance-attribute

num_kv_heads_local = num_kv_heads // tp_size
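The local head counts above are plain floor division by the tensor-parallel size, so as written they assume num_heads and num_kv_heads are both divisible by tp_size. The sketch below shows that arithmetic together with the grouped-query sharing factor. The helper name local_head_layout and its explicit divisibility checks are illustrative assumptions, not part of this module.

```python
def local_head_layout(num_heads: int, num_kv_heads: int, tp_size: int) -> dict[str, int]:
    """Hypothetical helper: per-rank head counts under tensor parallelism."""
    # Mirrors num_heads_local = num_heads // tp_size. The explicit checks are an
    # assumption: without them, floor division would silently drop heads.
    if num_heads % tp_size or num_kv_heads % tp_size:
        raise ValueError("head counts must be divisible by tp_size")
    if num_heads % num_kv_heads:
        raise ValueError("num_heads must be a multiple of num_kv_heads (GQA)")
    return {
        "num_heads_local": num_heads // tp_size,
        "num_kv_heads_local": num_kv_heads // tp_size,
        # Each KV head is shared by this many query heads (grouped-query attention).
        "queries_per_kv_head": num_heads // num_kv_heads,
    }


# Defaults from Cosmos3VFMTransformer: 32 attention heads, 8 KV heads.
print(local_head_layout(32, 8, tp_size=4))
# {'num_heads_local': 8, 'num_kv_heads_local': 2, 'queries_per_kv_head': 4}
```

With these defaults, tp_size can be 1, 2, 4 or 8; at 16 the 8 KV heads can no longer be split evenly.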

to_k instance-attribute

to_k = ColumnParallelLinear(
    hidden_size,
    num_kv_heads * head_dim,
    bias=False,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.to_k",
)

to_out instance-attribute

to_out = RowParallelLinear(
    num_heads * head_dim,
    hidden_size,
    bias=False,
    input_is_parallel=True,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.to_out",
)

to_q instance-attribute

to_q = ColumnParallelLinear(
    hidden_size,
    num_heads * head_dim,
    bias=False,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.to_q",
)

to_v instance-attribute

to_v = ColumnParallelLinear(
    hidden_size,
    num_kv_heads * head_dim,
    bias=False,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.to_v",
)

forward

forward(
    hidden_states: Tensor,
    freqs_cos: Tensor,
    freqs_sin: Tensor,
) -> tuple[Tensor, Tensor, Tensor]

Cosmos3CrossAttention

Bases: Module

Generation pathway: full attention where visual Q attends to all K/V.

  • Non-SP path: explicit cat([k_und, k_gen]). Text conditioning is always present because K/V are physically concatenated.

  • SP path (Ulysses active): k_und/v_und are passed as joint_key/joint_value in AttentionMetadata. The Ulysses wrapper head-slices the replicated UND K/V and performs all-to-all on the sharded GEN Q/K/V so every query sees the full context.
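To make the non-SP path concrete, here is a toy single-head version in plain Python. It uses lists instead of tensors and omits RoPE, QK-norm and GQA. It only illustrates the data flow described above: GEN queries are scored against the sequence-wise concatenation of UND and GEN keys, with no causal mask. The function names are hypothetical.

```python
import math


def softmax(scores: list[float]) -> list[float]:
    # Subtract the max for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def gen_cross_attention(q_gen, k_und, v_und, k_gen, v_gen):
    """Toy single-head, non-causal attention over [UND; GEN] keys and values.

    Every argument is a list of vectors (one per token); no RoPE, norm or GQA.
    """
    keys = k_und + k_gen        # sequence-wise concat, like cat([k_und, k_gen])
    values = v_und + v_gen
    scale = 1.0 / math.sqrt(len(q_gen[0]))   # softmax_scale = 1 / sqrt(head_dim)
    out = []
    for q in q_gen:             # no causal mask: each query scores every key
        weights = softmax([scale * sum(a * b for a, b in zip(q, k)) for k in keys])
        out.append([sum(w * v[d] for w, v in zip(weights, values))
                    for d in range(len(values[0]))])
    return out


# One text (UND) token and one visual (GEN) token with identical keys:
# the GEN query splits its attention 50/50 between text and visual values.
print(gen_cross_attention(q_gen=[[1.0, 1.0]],
                          k_und=[[0.0, 0.0]], v_und=[[2.0, 0.0]],
                          k_gen=[[0.0, 0.0]], v_gen=[[0.0, 2.0]]))
# [[1.0, 1.0]]
```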

attn instance-attribute

attn = Attention(
    num_heads=num_heads,
    head_size=head_dim,
    causal=False,
    softmax_scale=1.0 / head_dim**0.5,
    num_kv_heads=num_kv_heads,
)

head_dim instance-attribute

head_dim = head_dim

hidden_size instance-attribute

hidden_size = hidden_size

norm_k instance-attribute

norm_k = RMSNorm(head_dim, eps=rms_norm_eps)

norm_q instance-attribute

norm_q = RMSNorm(head_dim, eps=rms_norm_eps)

num_heads instance-attribute

num_heads = num_attention_heads

num_heads_local instance-attribute

num_heads_local = num_heads // tp_size

num_kv_heads instance-attribute

num_kv_heads = num_key_value_heads

num_kv_heads_local instance-attribute

num_kv_heads_local = num_kv_heads // tp_size

to_k instance-attribute

to_k = ColumnParallelLinear(
    hidden_size,
    num_kv_heads * head_dim,
    bias=False,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.to_k",
)

to_out instance-attribute

to_out = RowParallelLinear(
    num_heads * head_dim,
    hidden_size,
    bias=False,
    input_is_parallel=True,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.to_out",
)

to_q instance-attribute

to_q = ColumnParallelLinear(
    hidden_size,
    num_heads * head_dim,
    bias=False,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.to_q",
)

to_v instance-attribute

to_v = ColumnParallelLinear(
    hidden_size,
    num_kv_heads * head_dim,
    bias=False,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.to_v",
)

forward

forward(
    hidden_states: Tensor,
    k_und: Tensor,
    v_und: Tensor,
    freqs_cos: Tensor,
    freqs_sin: Tensor,
) -> Tensor

Parameters:

  • hidden_states (Tensor, required): [B, S_gen_local, hidden_size], may be sequence-sharded.
  • k_und (Tensor, required): [B, S_und, H_kv_local, D] pre-computed UND keys (TP-sharded, post-norm/RoPE).
  • v_und (Tensor, required): [B, S_und, H_kv_local, D] pre-computed UND values (TP-sharded).
  • freqs_cos (Tensor, required): [B, S_gen_local, 1, D].
  • freqs_sin (Tensor, required): [B, S_gen_local, 1, D].

Cosmos3GatedMLP

Bases: Module

down_proj instance-attribute

down_proj = RowParallelLinear(
    intermediate_size,
    hidden_size,
    bias=False,
    input_is_parallel=True,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.down_proj",
)

gate_proj instance-attribute

gate_proj = ColumnParallelLinear(
    hidden_size,
    intermediate_size,
    bias=False,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.gate_proj",
)

up_proj instance-attribute

up_proj = ColumnParallelLinear(
    hidden_size,
    intermediate_size,
    bias=False,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.up_proj",
)

forward

forward(x: Tensor) -> Tensor

Cosmos3GenDecoderLayer

Bases: Module

Generation pathway decoder layer: cross-attention (to UND K/V) + MLP.

cross_attention instance-attribute

cross_attention = Cosmos3CrossAttention(
    hidden_size=hidden_size,
    num_attention_heads=num_attention_heads,
    num_key_value_heads=num_key_value_heads,
    head_dim=head_dim,
    rms_norm_eps=rms_norm_eps,
    quant_config=quant_config,
    prefix=f"{prefix}.cross_attention",
)

input_layernorm instance-attribute

input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

layer_idx instance-attribute

layer_idx = layer_idx

mlp instance-attribute

mlp = Cosmos3GatedMLP(
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    quant_config=quant_config,
    prefix=f"{prefix}.mlp",
)

post_attention_layernorm instance-attribute

post_attention_layernorm = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

forward

forward(
    hidden_states: Tensor,
    *,
    k_und: Tensor | None = None,
    v_und: Tensor | None = None,
    freqs_cos: Tensor | None = None,
    freqs_sin: Tensor | None = None,
    cached_kv: list[tuple[Tensor, Tensor]] | None = None,
    freqs_gen: tuple[Tensor, Tensor] | None = None,
) -> Tensor

Cosmos3GenSPPrepare

Bases: Module

Module boundary used by _sp_plan to shard GEN states and RoPE together.

forward

forward(
    hidden_gen: Tensor, freqs_cos: Tensor, freqs_sin: Tensor
) -> tuple[Tensor, Tensor, Tensor]

Cosmos3LanguageModel

Bases: Module

Understanding pathway: a standard causal LM that processes text tokens.

Returns per-layer K/V tensors for the generation pathway's cross-attention. The UND pathway is independent of the denoising step, so its K/V can be computed once and reused across all sampling steps.
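Because the UND output does not depend on the denoising step, the caller can compute it once per prompt and hand the same per-layer K/V to every step. The sketch below shows that caching pattern, assuming the cache is invalidated explicitly between prompts. The class UndKVCache is hypothetical; in this module the corresponding state appears to be Cosmos3VFMTransformer.cached_kv together with reset_cache().

```python
from typing import Any, Callable, Optional


class UndKVCache:
    """Hypothetical helper: run the UND pathway once, reuse its K/V every step."""

    def __init__(self, compute_und_kv: Callable[[Any], list]):
        self._compute = compute_und_kv
        self.cached_kv: Optional[list] = None   # per-layer (K, V) pairs

    def get(self, text_ids: Any) -> list:
        # The UND output is timestep-independent, so only the first call computes.
        if self.cached_kv is None:
            self.cached_kv = self._compute(text_ids)
        return self.cached_kv

    def reset_cache(self) -> None:
        # Call before a new prompt; otherwise the previous prompt's K/V is reused.
        self.cached_kv = None


calls = []


def fake_language_model(text_ids):
    calls.append(text_ids)
    return [("K0", "V0"), ("K1", "V1")]   # one (K, V) pair per layer


cache = UndKVCache(fake_language_model)
for step in range(50):                    # e.g. 50 denoising steps
    kv = cache.get((101, 102, 103))
print(len(calls))  # 1
```

The cache is not keyed by the prompt, which keeps the per-step lookup trivial but makes the explicit reset mandatory.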

embed_tokens instance-attribute

embed_tokens = Embedding(vocab_size, hidden_size)

layers instance-attribute

layers = ModuleList(
    [
        (
            Cosmos3UndDecoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                head_dim=head_dim,
                rms_norm_eps=rms_norm_eps,
                quant_config=quant_config,
                prefix=f"{prefix}.layers.{i}",
            )
        )
        for i in range(num_hidden_layers)
    ]
)

norm instance-attribute

norm = RMSNorm(hidden_size, eps=rms_norm_eps)

rotary_emb instance-attribute

rotary_emb = Qwen3VLTextRotaryEmbedding(
    head_dim=head_dim,
    rope_theta=rope_theta,
    mrope_section=mrope_section,
)

forward

forward(
    text_ids: Tensor, freqs: tuple[Tensor, Tensor]
) -> list[tuple[Tensor, Tensor]]

Parameters:

  • text_ids (Tensor, required): [B, S] token IDs.
  • freqs (tuple[Tensor, Tensor], required): (cos, sin), each [B, S, 1, D].

Returns:

  • list[tuple[Tensor, Tensor]]: one (K, V) pair per layer, each [B, S, H_kv, D].

No padding mask is applied: with right padding and causal self-attention, real query positions attend only to real keys, and the caller trims the pad K/V via max_real_len before the GEN cross-attention sees them.
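The trimming step can be sketched as follows, assuming right padding so that each prompt's real length equals the number of ones in its text_mask row. The helper names max_real_len and trim_pad_kv are used here for illustration only; the page names max_real_len but does not show how it is computed.

```python
def max_real_len(text_mask: list[list[int]]) -> int:
    """Longest real prompt in the batch; with right padding, real length = sum(mask)."""
    return max(sum(row) for row in text_mask)


def trim_pad_kv(kv_per_layer, text_mask):
    """Slice every layer's K and V ([B][S][...] nested lists) to max_real_len."""
    n = max_real_len(text_mask)
    return [([k_b[:n] for k_b in k], [v_b[:n] for v_b in v]) for k, v in kv_per_layer]


mask = [[1, 1, 1, 0, 0],
        [1, 1, 0, 0, 0]]
k = [[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]   # stand-in "keys", one scalar per position
v = [[5, 6, 7, 8, 9], [15, 16, 17, 18, 19]]
print(trim_pad_kv([(k, v)], mask))
# [([[0, 1, 2], [10, 11, 12]], [[5, 6, 7], [15, 16, 17]])]
```

Trimming to the batch maximum removes only the padding shared by every prompt: the second prompt above still carries one pad position, and this page does not say how that remainder is handled.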

Cosmos3UndDecoderLayer

Bases: Module

Understanding pathway decoder layer: causal self-attention + MLP.

input_layernorm instance-attribute

input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

mlp instance-attribute

mlp = Cosmos3GatedMLP(
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    quant_config=quant_config,
    prefix=f"{prefix}.mlp",
)

post_attention_layernorm instance-attribute

post_attention_layernorm = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

self_attn instance-attribute

self_attn = Cosmos3CausalAttention(
    hidden_size=hidden_size,
    num_attention_heads=num_attention_heads,
    num_key_value_heads=num_key_value_heads,
    head_dim=head_dim,
    rms_norm_eps=rms_norm_eps,
    quant_config=quant_config,
    prefix=f"{prefix}.self_attn",
)

forward

forward(
    hidden_states: Tensor, freqs: tuple[Tensor, Tensor]
) -> tuple[Tensor, Tensor, Tensor]

Returns (hidden_states, K, V) where K/V are for GEN cross-attention.

Cosmos3VFMTransformer

Bases: Module

Cosmos3 VFM Transformer: UND language model + GEN denoising layers.

The UND pathway runs once per generation (K/V cached). The GEN pathway runs at each denoising step.

Layerwise offloading uses gen_layers as the block container.

Sequence parallelism uses _sp_plan to shard/gather the GEN pathway at module boundaries. Cosmos3CrossAttention checks forward_context.sp_active at runtime and routes to the framework Attention layer (with Ulysses all-to-all) or plain SDPA accordingly.

action_dim instance-attribute

action_dim = int(
    action_dim_value if action_dim_value is not None else 64
)

action_gen instance-attribute

action_gen = (
    _as_bool(action_gen_value)
    if action_gen_value is not None
    else False
)

action_modality_embed instance-attribute

action_modality_embed = Parameter(
    zeros(hidden_size, dtype=dtype)
)

action_proj_in instance-attribute

action_proj_in = DomainAwareLinear(
    action_dim,
    hidden_size,
    num_embodiment_domains,
    dtype=dtype,
)

action_proj_out instance-attribute

action_proj_out = DomainAwareLinear(
    hidden_size,
    action_dim,
    num_embodiment_domains,
    dtype=dtype,
)

audio_modality_embed instance-attribute

audio_modality_embed = Parameter(zeros(hidden_size))

audio_proj_in instance-attribute

audio_proj_in = Linear(sound_dim, hidden_size)

audio_proj_out instance-attribute

audio_proj_out = Linear(hidden_size, sound_dim)

base_fps instance-attribute

base_fps = float(
    _tf_config_get(model_config, "base_fps", 24.0)
)

cached_freqs_gen instance-attribute

cached_freqs_gen: tuple[Tensor, Tensor] | None = None

cached_kv instance-attribute

cached_kv: list[tuple[Tensor, Tensor]] | None = None

device property

device: device

enable_fps_modulation instance-attribute

enable_fps_modulation = bool(
    _tf_config_get(
        model_config, "enable_fps_modulation", True
    )
)

gen_layers instance-attribute

gen_layers = ModuleList(
    [
        (
            Cosmos3GenDecoderLayer(
                layer_idx=i,
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                head_dim=head_dim,
                rms_norm_eps=rms_norm_eps,
                quant_config=quant_config,
                prefix=f"gen_layers.{i}",
            )
        )
        for i in range(num_hidden_layers)
    ]
)

gen_sp_gather instance-attribute

gen_sp_gather = Identity()

gen_sp_prepare instance-attribute

gen_sp_prepare = Cosmos3GenSPPrepare()

head_dim instance-attribute

head_dim = int(
    _tf_config_get(model_config, "head_dim", 128)
)

hidden_size instance-attribute

hidden_size = int(
    _tf_config_get(model_config, "hidden_size", 4096)
)

intermediate_size instance-attribute

intermediate_size = int(
    _tf_config_get(model_config, "intermediate_size", 12288)
)

language_model instance-attribute

language_model = Cosmos3LanguageModel(
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    num_key_value_heads=num_key_value_heads,
    head_dim=head_dim,
    vocab_size=vocab_size,
    rms_norm_eps=rms_norm_eps,
    rope_theta=rope_theta,
    mrope_section=mrope_section,
    quant_config=quant_config,
    prefix="language_model",
)

latent_channel_size instance-attribute

latent_channel_size = int(
    _tf_config_get(model_config, "latent_channel", 48)
)

latent_patch_size instance-attribute

latent_patch_size = int(
    _tf_config_get(model_config, "latent_patch_size", 2)
)

mrope_section instance-attribute

mrope_section = list(get('mrope_section', [24, 20, 20]))

norm_moe_gen instance-attribute

norm_moe_gen = RMSNorm(hidden_size, eps=rms_norm_eps)

num_attention_heads instance-attribute

num_attention_heads = int(
    _tf_config_get(model_config, "num_attention_heads", 32)
)

num_embodiment_domains instance-attribute

num_embodiment_domains = int(
    _od_config_get(od_config, "num_embodiment_domains", 32)
)

num_hidden_layers instance-attribute

num_hidden_layers = int(
    _tf_config_get(model_config, "num_hidden_layers", 36)
)

num_key_value_heads instance-attribute

num_key_value_heads = int(
    _tf_config_get(model_config, "num_key_value_heads", 8)
)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {}

patch_latent_dim instance-attribute

patch_latent_dim = (
    latent_patch_size**2 * latent_channel_size
)

proj_in instance-attribute

proj_in = Linear(patch_latent_dim, hidden_size)

proj_out instance-attribute

proj_out = Linear(hidden_size, patch_latent_dim)

rms_norm_eps instance-attribute

rms_norm_eps = float(
    _tf_config_get(model_config, "rms_norm_eps", 1e-06)
)

rope_theta instance-attribute

rope_theta = float(
    _tf_config_get(model_config, "rope_theta", 5000000)
)

sound_dim instance-attribute

sound_dim = sound_dim

sound_gen instance-attribute

sound_gen = sound_gen

sound_latent_fps instance-attribute

sound_latent_fps = sound_latent_fps

temporal_compression_factor instance-attribute

temporal_compression_factor = int(
    temporal_compression_factor
)

temporal_compression_factor_sound instance-attribute

temporal_compression_factor_sound = int(
    _tf_config_get(
        model_config, "temporal_compression_factor_sound", 1
    )
)

temporal_modality_margin instance-attribute

temporal_modality_margin = int(
    _tf_config_get(
        model_config,
        "unified_3d_mrope_temporal_modality_margin",
        15000,
    )
)

time_embedder instance-attribute

time_embedder = TimestepEmbedder(
    hidden_size, target_dtype=dtype
)

timestep_scale instance-attribute

timestep_scale = float(
    _tf_config_get(model_config, "timestep_scale", 0.001)
)

vocab_size instance-attribute

vocab_size = int(
    _tf_config_get(model_config, "vocab_size", 151936)
)

forward

forward(
    hidden_states: Tensor,
    timestep: Tensor,
    text_ids: Tensor,
    text_mask: Tensor,
    video_shape: tuple[int, int, int],
    fps: float | None = None,
    action_latents: Tensor | None = None,
    action_domain_ids: Tensor | None = None,
    action_noisy_mask: Tensor | None = None,
    action_start_frame_offset: int = 1,
    action_fps: float | None = None,
    sound_latents: Tensor | None = None,
    noisy_frame_mask: Tensor | None = None,
    **kwargs,
) -> Tensor | tuple[Tensor, ...]

Parameters:

  • hidden_states (Tensor, required): [B, C, t, h, w] noisy latents.
  • timestep (Tensor, required): [B] diffusion timestep.
  • text_ids (Tensor, required): [B, S_text] tokenized text.
  • text_mask (Tensor, required): [B, S_text] attention mask (1 = real, 0 = pad).
  • video_shape (tuple[int, int, int], required): (t, h, w) in latent space.
  • fps (float | None, default None): video frame rate used for temporal mRoPE modulation.
  • action_latents (Tensor | None, default None): optional [B, T_action, D_action] noisy action latents.
  • action_domain_ids (Tensor | None, default None): optional [B] embodiment domain IDs for the action projections.
  • action_noisy_mask (Tensor | None, default None): optional [B, T_action, 1] mask where 1 = noisy action token and 0 = clean conditioning token.
  • sound_latents (Tensor | None, default None): optional [B, C_sound, T_sound] noisy sound latents.
  • noisy_frame_mask (Tensor | None, default None): optional [B, 1, t, 1, 1] mask where 1 = noisy (add the timestep embedding and predict velocity) and 0 = conditioning frame (clean context, no timestep embedding). None means all frames are noisy (T2V mode).

Returns:

  • Tensor | tuple[Tensor, ...]: the [B, C, t, h, w] velocity prediction, or a tuple of outputs in video, action, sound order when extra modalities are provided.
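The noisy_frame_mask rule (timestep embedding only on noisy frames, all frames noisy when the mask is None) can be sketched for a single batch item as below. This is a minimal sketch assuming the embedding is added elementwise to each token of a noisy frame; the real code works on [B, 1, t, 1, 1] tensors via broadcasting, and the function name is hypothetical.

```python
def add_timestep_embedding(frame_tokens, temb, noisy_frame_mask=None):
    """Add temb only to the tokens of noisy frames.

    frame_tokens: list over latent frames of lists of token vectors (lists of floats).
    temb: timestep embedding with the same width as each token.
    noisy_frame_mask: per-frame 0/1 flags; None means every frame is noisy (T2V).
    """
    if noisy_frame_mask is None:
        noisy_frame_mask = [1] * len(frame_tokens)
    out = []
    for tokens, noisy in zip(frame_tokens, noisy_frame_mask):
        if noisy:
            out.append([[x + e for x, e in zip(tok, temb)] for tok in tokens])
        else:
            out.append([list(tok) for tok in tokens])   # clean context frame: unchanged
    return out


frames = [[[0.0, 0.0]], [[1.0, 1.0]]]   # two latent frames, one token each
print(add_timestep_embedding(frames, [0.5, 0.5], noisy_frame_mask=[0, 1]))
# [[[0.0, 0.0]], [[1.5, 1.5]]]
```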

pack_action

pack_action(action_latents: Tensor) -> Tensor

Validate and return action latents as [B, T_action, D_action] tokens.

pack_sound

pack_sound(sound_latents: Tensor) -> Tensor

[B, C_sound, T_sound] -> [B, T_sound, C_sound].

patchify

patchify(latents: Tensor, t: int, h: int, w: int) -> Tensor

[B, C, t, h, w] -> [B, t*hp*wp, p*p*C], padding h/w if needed (p = latent_patch_size; hp and wp are the padded height and width divided by p).

post_load_weights

post_load_weights() -> None

Post-load processing: ensure correct dtypes.

reset_cache

reset_cache() -> None

unpack_action staticmethod

unpack_action(tokens: Tensor) -> Tensor

Return [B, T_action, D_action] action predictions.

unpack_sound staticmethod

unpack_sound(tokens: Tensor) -> Tensor

[B, T_sound, C_sound] -> [B, C_sound, T_sound].
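pack_sound and unpack_sound are documented as pure axis swaps between channel-major sound latents and token-major sequences. Here is a list-based sketch of that transpose, assuming nothing beyond the documented shapes. In PyTorch the same swap would typically be written as transpose(1, 2).

```python
def pack_sound(sound_latents):
    """[B, C_sound, T_sound] -> [B, T_sound, C_sound] (per-sample transpose)."""
    return [[list(col) for col in zip(*sample)] for sample in sound_latents]


def unpack_sound(tokens):
    """[B, T_sound, C_sound] -> [B, C_sound, T_sound]; the inverse of pack_sound."""
    return [[list(col) for col in zip(*sample)] for sample in tokens]


x = [[[1, 2, 3], [4, 5, 6]]]   # B=1, C_sound=2, T_sound=3
print(pack_sound(x))            # [[[1, 4], [2, 5], [3, 6]]]
```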

unpatchify

unpatchify(
    tokens: Tensor, t: int, h: int, w: int
) -> Tensor

[B, t*hp*wp, p*p*C] -> [B, C, t, h, w], cropping padding if needed.
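The shape bookkeeping behind patchify and unpatchify can be checked with a few lines of arithmetic. The sketch below assumes h and w are padded up to the next multiple of p and that only the spatial axes are patched, which is how the token count t*hp*wp reads. Where the padding is placed and the order in which a patch is flattened are not specified on this page, so the sketch computes sizes only. The helper name is hypothetical.

```python
def patchify_shape(c: int, t: int, h: int, w: int, p: int = 2) -> dict[str, int]:
    """Token layout for [B, C, t, h, w] -> [B, t*hp*wp, p*p*C] with p x p spatial patches."""
    hp = (h + p - 1) // p          # patches along height after padding (ceil division)
    wp = (w + p - 1) // p          # patches along width after padding
    return {
        "pad_h": hp * p - h,       # rows added by patchify, cropped again by unpatchify
        "pad_w": wp * p - w,       # columns added by patchify, cropped again by unpatchify
        "num_tokens": t * hp * wp, # GEN sequence length for the video
        "token_dim": p * p * c,    # equals patch_latent_dim = latent_patch_size**2 * latent_channel_size
    }


# Defaults: latent_channel_size = 48, latent_patch_size = 2; odd latent height 45.
print(patchify_shape(48, 3, 45, 80))
# {'pad_h': 1, 'pad_w': 0, 'num_tokens': 2760, 'token_dim': 192}
```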

DomainAwareLinear

Bases: Module

Linear projection with one weight/bias pair per action embodiment domain.
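Since fc is an Embedding with output_size * input_size entries per domain, a domain ID selects a flat weight vector that is reshaped into a matrix, and bias selects the matching bias row. The sketch below shows that lookup for a single sample in plain Python. Reading the flat vector row-major as [output_size, input_size] is an assumption; the page does not state the layout. The function name is hypothetical.

```python
def domain_aware_linear(x, domain_id, fc_table, bias_table, input_size, output_size):
    """y = W[d] @ x + b[d], with W[d] read from a flat per-domain table row.

    fc_table[d] holds output_size * input_size weights for domain d, read here
    row-major as an [output_size, input_size] matrix (an assumption).
    """
    flat = fc_table[domain_id]
    if len(flat) != output_size * input_size:
        raise ValueError("weight row has the wrong size")
    b = bias_table[domain_id]
    return [
        sum(flat[o * input_size + i] * x[i] for i in range(input_size)) + b[o]
        for o in range(output_size)
    ]


# Two embodiment domains, input_size=2, output_size=1.
fc_table = [[1.0, 0.0],    # domain 0 picks the first input
            [0.0, 1.0]]    # domain 1 picks the second input
bias_table = [[0.0], [10.0]]
print(domain_aware_linear([3.0, 4.0], 1, fc_table, bias_table, 2, 1))   # [14.0]
```

Storing weights in an Embedding lets a batch gather one projection per sample by domain ID instead of keeping num_domains separate Linear modules.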

bias instance-attribute

bias = Embedding(num_domains, output_size, dtype=dtype)

fc instance-attribute

fc = Embedding(
    num_domains, output_size * input_size, dtype=dtype
)

input_size instance-attribute

input_size = int(input_size)

num_domains instance-attribute

num_domains = int(num_domains)

output_size instance-attribute

output_size = int(output_size)

forward

forward(x: Tensor, domain_id: Tensor) -> Tensor

Qwen3VLTextRotaryEmbedding

Bases: Module

Multi-dimensional rotary position embedding for Qwen3-VL.

attention_scaling instance-attribute

attention_scaling = 1.0

head_dim instance-attribute

head_dim = head_dim

mrope_section instance-attribute

mrope_section = mrope_section

rope_theta instance-attribute

rope_theta = rope_theta

apply_interleaved_mrope

apply_interleaved_mrope(
    freqs: Tensor, mrope_section: list[int]
) -> Tensor

Reorganize from chunked [TTT...HHH...WWW] to interleaved [THTHW...].
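The sketch below assumes the interleaving scheme used by the Hugging Face Qwen3-VL text rotary embedding. Under that scheme you start from the T frequencies, then overwrite slots 1, 4, 7, ... with H and slots 2, 5, 8, ... with W, each up to 3 * mrope_section[axis]. It reports only which axis feeds each frequency slot. The exact pattern in this module's own implementation is not shown on this page, and the function name is hypothetical.

```python
def interleaved_mrope_layout(mrope_section: list[int]) -> str:
    """Axis (T/H/W) that supplies each rotary frequency slot after interleaving."""
    total = sum(mrope_section)          # head_dim // 2 frequency slots
    layout = ["T"] * total              # start from the temporal frequencies
    for axis, (name, offset) in enumerate((("H", 1), ("W", 2)), start=1):
        # Overwrite every third slot, up to 3 * section size for this axis.
        for idx in range(offset, mrope_section[axis] * 3, 3):
            layout[idx] = name
    return "".join(layout)


# Default mrope_section [24, 20, 20] with head_dim 128 -> 64 slots.
print(interleaved_mrope_layout([24, 20, 20]))
```

Under this assumption each axis keeps the slot count given by its section (24 T, 20 H, 20 W), and the T slots beyond 3 * 20 form a trailing run.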

forward

forward(
    x: Tensor, position_ids: Tensor
) -> tuple[Tensor, Tensor]

TimestepEmbedder

Bases: Module

Embeds scalar timesteps into vector representations.
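The attributes listed below (frequency_embedding_size, linear_1, a SiLU act, linear_2) match the common DiT-style timestep embedder. In that design a sinusoidal frequency embedding is fed through linear_1, SiLU and linear_2. The sketch covers only the sinusoidal step and assumes the DiT convention (cosines first, max_period = 10000). Whether this module uses the same ordering and period, and whether timestep_scale (0.001 by default) is applied beforehand, is not stated here.

```python
import math


def timestep_frequency_embedding(t: float, dim: int = 256,
                                 max_period: float = 10000.0) -> list[float]:
    """Sinusoidal features for one scalar timestep: [cos(t * f_i)..., sin(t * f_i)...]."""
    half = dim // 2
    # Geometric ladder of frequencies from 1 down towards 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:
        emb.append(0.0)   # zero-pad odd widths, as in the DiT reference code
    return emb


emb = timestep_frequency_embedding(500.0)
print(len(emb))   # 256, then projected to hidden_size by linear_1 -> SiLU -> linear_2
```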

act instance-attribute

act = SiLU()

frequency_embedding_size instance-attribute

frequency_embedding_size = frequency_embedding_size

hidden_size instance-attribute

hidden_size = hidden_size

linear_1 instance-attribute

linear_1 = Linear(
    frequency_embedding_size, hidden_size, bias=True
)

linear_2 instance-attribute

linear_2 = Linear(hidden_size, hidden_size, bias=True)

forward

forward(t: Tensor) -> Tensor

compute_mrope_position_ids_action

compute_mrope_position_ids_action(
    grid_t: int,
    temporal_offset: int | float,
    action_fps: float | None,
    base_fps: float = 24.0,
    base_temporal_compression_factor: int = 4,
    enable_fps_modulation: bool = True,
    start_frame_offset: int = 1,
) -> tuple[Tensor, int | float]

Generate mRoPE IDs for action tokens as a frame-rate (T, 1, 1) grid.

compute_mrope_position_ids_sound

compute_mrope_position_ids_sound(
    grid_t: int,
    temporal_offset: int | float,
    sound_latent_fps: float,
    base_fps: float = 24.0,
    temporal_compression_factor_sound: int = 1,
    enable_fps_modulation: bool = True,
) -> tuple[Tensor, int | float]

Generate mRoPE IDs for sound tokens as a (T, 1, 1) grid.

compute_mrope_position_ids_text

compute_mrope_position_ids_text(
    num_tokens: int, temporal_offset: int
) -> tuple[Tensor, int]

Generate 3D mRoPE position IDs for text tokens.

Text tokens: all three axes (t, h, w) share the same monotonically increasing position IDs.
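Per its signature, the function returns the position IDs together with an offset for the next segment. The sketch below assumes the IDs start at temporal_offset and that the returned offset is one past the last ID. It uses lists in place of a [3, num_tokens] tensor, and its name is hypothetical.

```python
def text_mrope_position_ids(num_tokens: int,
                            temporal_offset: int) -> tuple[list[list[int]], int]:
    """[3][num_tokens] position IDs for text; the t, h and w rows are identical."""
    ids = list(range(temporal_offset, temporal_offset + num_tokens))
    # Assumed next offset: one past the last ID, so the following segment continues.
    return [ids, list(ids), list(ids)], temporal_offset + num_tokens


pos, next_offset = text_mrope_position_ids(3, temporal_offset=10)
print(pos, next_offset)   # [[10, 11, 12], [10, 11, 12], [10, 11, 12]] 13
```

Sharing one ID across all three axes makes the 3D mRoPE reduce to ordinary 1D RoPE for text, which keeps the text pathway compatible with a standard causal LM.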

compute_mrope_position_ids_vision

compute_mrope_position_ids_vision(
    grid_t: int,
    grid_h: int,
    grid_w: int,
    temporal_offset: int | float,
    fps: float | None = None,
    base_fps: float = 24.0,
    temporal_compression_factor: int = 4,
    base_temporal_compression_factor: int | None = None,
    enable_fps_modulation: bool = True,
    start_frame_offset: int = 0,
) -> tuple[Tensor, int | float]

Generate 3D mRoPE position IDs for vision tokens.

Creates a (t, h, w) position grid with spatial indices reset per segment (Qwen3VL-style). Flattened in t-major order.
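A simplified sketch of the grid construction and t-major flattening follows. It ignores FPS modulation, advances the temporal ID by one per latent frame from temporal_offset, and restarts the spatial IDs at a hypothetical spatial_base (0 here) for every segment. All three of these are assumptions made for illustration; the real function can return fractional offsets (int | float) when FPS modulation is enabled.

```python
def vision_mrope_position_ids(grid_t: int, grid_h: int, grid_w: int,
                              temporal_offset: int, spatial_base: int = 0):
    """[3][t*h*w] (t, h, w) position IDs, flattened t-major, then h, then w."""
    t_ids, h_ids, w_ids = [], [], []
    for ti in range(grid_t):
        for hi in range(grid_h):
            for wi in range(grid_w):
                t_ids.append(temporal_offset + ti)   # time advances per latent frame
                h_ids.append(spatial_base + hi)      # spatial IDs restart every segment
                w_ids.append(spatial_base + wi)
    # Assumed next offset: one past the last temporal ID.
    return [t_ids, h_ids, w_ids], temporal_offset + grid_t


pos, next_offset = vision_mrope_position_ids(2, 2, 3, temporal_offset=5)
print(pos[0])        # [5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6]
print(next_offset)   # 7
```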

resolve_sound_gen

resolve_sound_gen(od_config: Any) -> bool

Capability gate shared by the pipeline and transformer.

Explicit sound_gen flag wins (including an explicit False); otherwise infer from the presence of any sound-width key in od_config.
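The precedence rule can be sketched as below. This is a minimal version that treats od_config as a plain dict and treats a missing or None sound_gen as "not set". The real od_config is typed Any and may be an object. The tuple of sound-width key names is hypothetical, since the page does not list the actual keys.

```python
from typing import Any

# Hypothetical key names: the actual "sound-width" keys are not listed on this page.
SOUND_WIDTH_KEYS = ("sound_dim", "audio_dim")


def resolve_sound_gen(od_config: dict[str, Any]) -> bool:
    """Explicit sound_gen wins (even False); otherwise infer from sound-width keys."""
    explicit = od_config.get("sound_gen")
    if explicit is not None:
        return bool(explicit)
    return any(key in od_config for key in SOUND_WIDTH_KEYS)


print(resolve_sound_gen({"sound_gen": False, "sound_dim": 64}))   # False
print(resolve_sound_gen({"sound_dim": 64}))                        # True
```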