Skip to content

vllm_omni.diffusion.models.helios.helios_transformer

logger module-attribute

logger = init_logger(__name__)

ColumnParallelGELU

Bases: Module

Column-parallel linear layer with GELU activation.

approximate instance-attribute

approximate = approximate

proj instance-attribute

proj = ColumnParallelLinear(
    dim_in,
    dim_out,
    bias=bias,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
)

forward

forward(x: Tensor) -> Tensor

DistributedRMSNorm

Bases: Module

RMSNorm that computes the global RMS across tensor-parallel ranks.

eps instance-attribute

eps = eps

weight instance-attribute

weight = Parameter(ones(hidden_size))

forward

forward(x: Tensor) -> Tensor

HeliosCrossAttention

Bases: Module

Optimized cross-attention for Helios.

attn instance-attribute

attn = Attention(
    num_heads=num_heads,
    head_size=head_dim,
    num_kv_heads=num_heads,
    softmax_scale=1.0 / head_dim**0.5,
    causal=False,
)

dim instance-attribute

dim = dim

dropout instance-attribute

dropout = Dropout(dropout)

head_dim instance-attribute

head_dim = head_dim

inner_dim instance-attribute

inner_dim = num_heads * head_dim

norm_k instance-attribute

norm_k = DistributedRMSNorm(tp_inner_dim, eps=eps)

norm_q instance-attribute

norm_q = DistributedRMSNorm(tp_inner_dim, eps=eps)

num_heads instance-attribute

num_heads = num_heads // tp_size

to_k instance-attribute

to_k = ColumnParallelLinear(
    dim,
    inner_dim,
    bias=True,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
)

to_out instance-attribute

to_out = RowParallelLinear(
    inner_dim,
    dim,
    bias=True,
    input_is_parallel=True,
    return_bias=False,
    quant_config=quant_config,
)

to_q instance-attribute

to_q = ColumnParallelLinear(
    dim,
    inner_dim,
    bias=True,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
)

to_v instance-attribute

to_v = ColumnParallelLinear(
    dim,
    inner_dim,
    bias=True,
    gather_output=False,
    return_bias=False,
    quant_config=quant_config,
)

tp_inner_dim instance-attribute

tp_inner_dim = num_heads * head_dim

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor | None = None,
    encoder_key_value: tuple[Tensor, Tensor] | None = None,
) -> Tensor

project_kv

project_kv(
    encoder_hidden_states: Tensor,
) -> tuple[Tensor, Tensor]

HeliosFeedForward

Bases: Module

Tensor-parallel (TP) feed-forward network for Helios.

net_0 instance-attribute

net_0 = ColumnParallelGELU(
    dim,
    inner_dim,
    approximate="tanh",
    bias=bias,
    quant_config=quant_config,
)

net_1 instance-attribute

net_1 = Identity()

net_2 instance-attribute

net_2 = RowParallelLinear(
    inner_dim,
    dim_out,
    bias=bias,
    input_is_parallel=True,
    return_bias=False,
    quant_config=quant_config,
)

forward

forward(hidden_states: Tensor) -> Tensor

HeliosOutputNorm

Bases: Module

Output normalization that extracts only original_context_length tokens.

norm instance-attribute

norm = FP32LayerNorm(dim, eps, elementwise_affine=False)

scale_shift_table instance-attribute

scale_shift_table = Parameter(randn(1, 2, dim) / dim**0.5)

forward

forward(
    hidden_states: Tensor,
    temb: Tensor,
    original_context_length: int,
)

HeliosRotaryPosEmbed

Bases: Module

Helios-style 3D rotary position embeddings using explicit frame indices.

theta instance-attribute

theta = theta

forward

forward(frame_indices, height, width, device)

get_frequency_batched

get_frequency_batched(freqs_base, pos)

HeliosSelfAttention

Bases: Module

Optimized self-attention for Helios with history amplification support.

attn instance-attribute

attn = Attention(
    num_heads=num_heads,
    head_size=head_dim,
    num_kv_heads=num_kv_heads,
    softmax_scale=1.0 / head_dim**0.5,
    causal=False,
)

dim instance-attribute

dim = dim

dropout instance-attribute

dropout = Dropout(dropout)

head_dim instance-attribute

head_dim = head_dim

history_key_scale instance-attribute

history_key_scale = Parameter(ones(1))

history_scale_mode instance-attribute

history_scale_mode = history_scale_mode

inner_dim instance-attribute

inner_dim = num_heads * head_dim

is_amplify_history instance-attribute

is_amplify_history = is_amplify_history

max_scale instance-attribute

max_scale = 10.0

norm_k instance-attribute

norm_k = DistributedRMSNorm(tp_inner_dim, eps=eps)

norm_q instance-attribute

norm_q = DistributedRMSNorm(tp_inner_dim, eps=eps)

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

to_out instance-attribute

to_out = RowParallelLinear(
    inner_dim,
    dim,
    bias=True,
    input_is_parallel=True,
    return_bias=False,
    quant_config=quant_config,
)

to_qkv instance-attribute

to_qkv = QKVParallelLinear(
    hidden_size=dim,
    head_size=head_dim,
    total_num_heads=num_heads,
    bias=True,
    quant_config=quant_config,
)

tp_inner_dim instance-attribute

tp_inner_dim = num_heads * head_dim

forward

forward(
    hidden_states: Tensor,
    rotary_emb: Tensor | None = None,
    original_context_length: int | None = None,
) -> Tensor

HeliosTimeTextEmbedding

Bases: Module

Combined time and text condition embeddings for Helios.

act_fn instance-attribute

act_fn = SiLU()

text_embedder instance-attribute

text_embedder = PixArtAlphaTextProjection(
    text_embed_dim, dim, act_fn="gelu_tanh"
)

time_embedder instance-attribute

time_embedder = TimestepEmbedding(
    in_channels=time_freq_dim, time_embed_dim=dim
)

time_proj instance-attribute

time_proj = Linear(dim, time_proj_dim)

timesteps_proj instance-attribute

timesteps_proj = Timesteps(
    num_channels=time_freq_dim,
    flip_sin_to_cos=True,
    downscale_freq_shift=0,
)

forward

forward(
    timestep: Tensor,
    encoder_hidden_states: Tensor | None = None,
    is_return_encoder_hidden_states: bool = True,
)

HeliosTransformer3DModel

Bases: Module

Optimized Helios Transformer model for video generation using vLLM layers.

Helios extends the Wan2.2 architecture with multi-term memory patches, guidance cross-attention, and chunked video generation support.

blocks instance-attribute

blocks = ModuleList(
    [
        (
            HeliosTransformerBlock(
                inner_dim,
                ffn_dim,
                num_attention_heads,
                eps,
                cross_attn_norm,
                guidance_cross_attn=guidance_cross_attn,
                is_amplify_history=is_amplify_history,
                history_scale_mode=history_scale_mode,
                quant_config=quant_config,
            )
        )
        for _ in (range(num_layers))
    ]
)

condition_embedder instance-attribute

condition_embedder = HeliosTimeTextEmbedding(
    dim=inner_dim,
    time_freq_dim=freq_dim,
    time_proj_dim=inner_dim * 6,
    text_embed_dim=text_dim,
)

config instance-attribute

config = type(
    "Config",
    (),
    {
        "patch_size": patch_size,
        "num_attention_heads": num_attention_heads,
        "attention_head_dim": attention_head_dim,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "text_dim": text_dim,
        "freq_dim": freq_dim,
        "ffn_dim": ffn_dim,
        "num_layers": num_layers,
        "cross_attn_norm": cross_attn_norm,
        "qk_norm": qk_norm,
        "eps": eps,
        "added_kv_proj_dim": added_kv_proj_dim,
        "rope_dim": rope_dim,
        "rope_theta": rope_theta,
        "guidance_cross_attn": guidance_cross_attn,
        "zero_history_timestep": zero_history_timestep,
        "has_multi_term_memory_patch": has_multi_term_memory_patch,
        "is_amplify_history": is_amplify_history,
        "history_scale_mode": history_scale_mode,
    },
)()

dtype property

dtype: dtype

has_multi_term_memory_patch instance-attribute

has_multi_term_memory_patch = has_multi_term_memory_patch

inner_dim instance-attribute

inner_dim = inner_dim

norm_out instance-attribute

norm_out = HeliosOutputNorm(inner_dim, eps)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "to_qkv": ["to_q", "to_k", "to_v"]
}

patch_embedding instance-attribute

patch_embedding = Conv3dLayer(
    in_channels=in_channels,
    out_channels=inner_dim,
    kernel_size=patch_size,
    stride=patch_size,
)

patch_long instance-attribute

patch_long = Conv3dLayer(
    in_channels=in_channels,
    out_channels=inner_dim,
    kernel_size=(4, 8, 8),
    stride=(4, 8, 8),
)

patch_mid instance-attribute

patch_mid = Conv3dLayer(
    in_channels=in_channels,
    out_channels=inner_dim,
    kernel_size=(2, 4, 4),
    stride=(2, 4, 4),
)

patch_short instance-attribute

patch_short = Conv3dLayer(
    in_channels=in_channels,
    out_channels=inner_dim,
    kernel_size=(1, 2, 2),
    stride=(1, 2, 2),
)

proj_out instance-attribute

proj_out = Linear(
    inner_dim, out_channels * prod(patch_size)
)

rope instance-attribute

rope = HeliosRotaryPosEmbed(
    rope_dim=rope_dim, theta=rope_theta
)

zero_history_timestep instance-attribute

zero_history_timestep = zero_history_timestep

clear_cross_attention_cache

clear_cross_attention_cache() -> None

forward

forward(
    hidden_states: Tensor,
    timestep: LongTensor,
    encoder_hidden_states: Tensor,
    indices_hidden_states: Tensor | None = None,
    indices_latents_history_short: Tensor | None = None,
    indices_latents_history_mid: Tensor | None = None,
    indices_latents_history_long: Tensor | None = None,
    latents_history_short: Tensor | None = None,
    latents_history_mid: Tensor | None = None,
    latents_history_long: Tensor | None = None,
    return_dict: bool = True,
    attention_kwargs: dict[str, Any] | None = None,
) -> Tensor | Transformer2DModelOutput

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights with QKV fusion, FFN remapping, and tensor-parallel (TP) norm sharding.

HeliosTransformerBlock

Bases: Module

Transformer block with guidance cross-attention and history support.

attn1 instance-attribute

attn1 = HeliosSelfAttention(
    dim=dim,
    num_heads=num_heads,
    head_dim=head_dim,
    eps=eps,
    is_amplify_history=is_amplify_history,
    history_scale_mode=history_scale_mode,
    quant_config=quant_config,
)

attn2 instance-attribute

attn2 = HeliosCrossAttention(
    dim=dim,
    num_heads=num_heads,
    head_dim=head_dim,
    eps=eps,
    quant_config=quant_config,
)

ffn instance-attribute

ffn = HeliosFeedForward(
    dim=dim,
    inner_dim=ffn_dim,
    dim_out=dim,
    quant_config=quant_config,
)

guidance_cross_attn instance-attribute

guidance_cross_attn = guidance_cross_attn

norm1 instance-attribute

norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)

norm2 instance-attribute

norm2 = (
    FP32LayerNorm(dim, eps, elementwise_affine=True)
    if cross_attn_norm
    else Identity()
)

norm3 instance-attribute

norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

scale_shift_table instance-attribute

scale_shift_table = Parameter(randn(1, 6, dim) / dim**0.5)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor,
    temb: Tensor,
    rotary_emb: Tensor,
    original_context_length: int | None = None,
    cross_attn_key_value: tuple[Tensor, Tensor]
    | None = None,
) -> Tensor

apply_rotary_emb_helios

apply_rotary_emb_helios(
    hidden_states: Tensor, freqs_cis: Tensor
) -> Tensor

Apply Helios-style rotary embeddings.

freqs_cis contains [cos_t, cos_y, cos_x, sin_t, sin_y, sin_x] concatenated along the last dimension, so its shape is [B, seq, D*2], where D = DT + DY + DX. hidden_states has shape [B, seq, H, head_dim].

center_down_sample_3d

center_down_sample_3d(x, kernel_size)

pad_for_3d_conv

pad_for_3d_conv(x, kernel_size)