
vllm_omni.model_executor.layers.timestep_embedding

Shared timestep embedding primitives for diffusion models.

  • SinusPositionEmbedding: sin/cos positional encoding (TTS DiT/CFM). Used by Qwen3-TTS, Qwen2.5-Omni, CosyVoice3, Ming-Flash-Omni.
  • DiTTimestepEmbedding: SinusPositionEmbedding followed by a Linear + SiLU + Linear MLP. Used by the same models as above.
  • timestep_embedding(): standalone function (GLIDE/DiT convention). Used by Bagel, NextStep, Z-Image, HunyuanImage3.

DiTTimestepEmbedding

Bases: Module

Timestep conditioning: SinusPositionEmbedding + Linear + SiLU + Linear.

Parameters:

Name            Type  Description                                     Default
dim             int   Hidden dimension (output size).                 required
freq_embed_dim  int   Sinusoidal embedding dimension (input to MLP).  256

time_embed instance-attribute

time_embed = SinusPositionEmbedding(freq_embed_dim)

time_mlp instance-attribute

time_mlp = Sequential(
    Linear(freq_embed_dim, dim), SiLU(), Linear(dim, dim)
)

forward

forward(timestep: Tensor) -> Tensor
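The composition above can be sketched in PyTorch as follows. This is a minimal sketch, not the vllm-omni source. It assumes an F5-TTS-style sinusoid: sin half first, log-spaced frequencies with a `half_dim - 1` denominator, and `scale` multiplying the timestep. The cast to the MLP's parameter dtype in `DiTTimestepEmbedding.forward` is also an assumption of this sketch.

```python
import math

import torch
from torch import nn


class SinusPositionEmbedding(nn.Module):
    """Sketch: maps (N,) scalar timesteps to (N, dim) sin/cos features."""

    def __init__(self, dim: int):
        super().__init__()
        # The denominator below is half_dim - 1, so this sketch needs dim >= 4.
        if dim % 2 != 0 or dim < 4:
            raise ValueError("dim must be even and at least 4 in this sketch")
        self.dim = dim

    def forward(self, x: torch.Tensor, scale: float = 1000.0) -> torch.Tensor:
        half_dim = self.dim // 2
        # Log-spaced frequencies from 1 down to 1/10000 (assumed F5-TTS-style formula).
        step = math.log(10000.0) / (half_dim - 1)
        freqs = torch.exp(-step * torch.arange(half_dim, device=x.device, dtype=torch.float32))
        args = scale * x.float()[:, None] * freqs[None, :]
        emb = torch.cat((args.sin(), args.cos()), dim=-1)  # sin half, then cos half
        # Cast back to the input dtype for float inputs; integer inputs stay float32 here.
        return emb.to(x.dtype) if x.is_floating_point() else emb


class DiTTimestepEmbedding(nn.Module):
    """Sketch: SinusPositionEmbedding followed by Linear -> SiLU -> Linear."""

    def __init__(self, dim: int, freq_embed_dim: int = 256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
        )

    def forward(self, timestep: torch.Tensor) -> torch.Tensor:
        t_freq = self.time_embed(timestep)  # (N, freq_embed_dim)
        # Assumption: match the first Linear's parameter dtype before the MLP.
        t_freq = t_freq.to(self.time_mlp[0].weight.dtype)
        return self.time_mlp(t_freq)  # (N, dim)
```

For example, `DiTTimestepEmbedding(dim=1024)(torch.rand(4))` returns a `(4, 1024)` tensor. The sinusoid is fixed, and only the two Linear layers are learned.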

SinusPositionEmbedding

Bases: Module

Sinusoidal position embedding for scalar timesteps.

Maps scalar timestep values to dim-dimensional embeddings using the standard log-spaced frequency formula from DDPM/DiT.
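The formula is not spelled out here. Written out under the F5-TTS-style formula (an assumption, not confirmed against the vllm-omni source), with h = dim/2 and s = scale, it is:

```latex
\begin{aligned}
h &= \mathrm{dim}/2, \qquad \omega_k = \exp\!\left(-\frac{k \ln 10000}{h-1}\right), \quad k = 0, \dots, h-1, \\
\mathrm{emb}(x) &= \big[\sin(s\,x\,\omega_0), \dots, \sin(s\,x\,\omega_{h-1}),\; \cos(s\,x\,\omega_0), \dots, \cos(s\,x\,\omega_{h-1})\big].
\end{aligned}
```

Under this reading the frequencies fall geometrically from 1 to 1/10000, and the sin half precedes the cos half. The h - 1 denominator also means dim must be at least 4.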

Parameters:

Name  Type  Description                                 Default
dim   int   Output embedding dimension (must be even).  required

dim instance-attribute

dim = dim

forward

forward(x: Tensor, scale: float = 1000.0) -> Tensor

Parameters:

Name   Type    Description                Default
x      Tensor  (N,) scalar timesteps.     required
scale  float   Frequency scaling factor.  1000.0

Returns:

Type    Description
Tensor  (N, dim) sinusoidal embeddings, cast to the input dtype.

timestep_embedding

timestep_embedding(
    t: Tensor, dim: int, max_period: float = 10000.0
) -> Tensor

Create sinusoidal timestep embeddings (GLIDE/DiT convention).

Produces cos-then-sin embeddings with log-spaced frequencies.
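In the GLIDE reference implementation, which gives this convention its name, the embedding of a timestep t is defined as follows, with P = max_period and h = floor(dim/2). Whether this function also zero-pads an odd dim the way GLIDE does is not stated here.

```latex
\begin{aligned}
h &= \lfloor \mathrm{dim}/2 \rfloor, \qquad \omega_k = \exp\!\left(-\frac{k \ln P}{h}\right), \quad k = 0, \dots, h-1, \\
\mathrm{emb}(t) &= \big[\cos(t\,\omega_0), \dots, \cos(t\,\omega_{h-1}),\; \sin(t\,\omega_0), \dots, \sin(t\,\omega_{h-1})\big].
\end{aligned}
```

The lowest frequency is \omega_{h-1} = P^{-(h-1)/h}, slightly above 1/P, which is how max_period sets the minimum frequency. The result is not interchangeable with SinusPositionEmbedding: this function puts cos before sin and takes no scale argument.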

Parameters:

Name        Type    Description                                              Default
t           Tensor  (N,) 1-D tensor of timestep indices (may be fractional).  required
dim         int     Output embedding dimension.                              required
max_period  float   Controls the minimum frequency.                          10000.0

Returns:

Type    Description
Tensor  (N, dim) tensor of positional embeddings.
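A minimal PyTorch sketch of a GLIDE-convention function with this signature follows. It is modeled on the GLIDE reference code rather than taken from vllm-omni. The zero column appended for an odd dim follows GLIDE and is an assumption for this module.

```python
import math

import torch


def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor:
    """Sketch of the GLIDE/DiT-convention sinusoidal timestep embedding."""
    half = dim // 2
    # Frequencies decay geometrically from 1 toward 1/max_period.
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32, device=t.device) / half
    )
    args = t[:, None].float() * freqs[None, :]  # (N, half)
    # cos half first, then sin half (GLIDE ordering).
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # GLIDE pads an odd dim with one zero column (assumed here as well).
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
```

Because the first frequency is exactly 1, column 0 of the result is cos(t) and column `dim // 2` is sin(t).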