Skip to content

vllm_omni.diffusion.models.bagel.bagel_transformer

logger module-attribute

logger = init_logger(__name__)

Bagel

Bases: Module

base_model_prefix class-attribute instance-attribute

base_model_prefix = 'bagel'

config instance-attribute

config = config

config_class class-attribute instance-attribute

config_class = BagelConfig

connector instance-attribute

connector = MLPconnector(
    self.vit_hidden_size,
    self.hidden_size,
    config.connector_act,
    quant_config=quant_config,
    prefix=f"{prefix}.connector",
)

get_flattened_position_ids instance-attribute

get_flattened_position_ids = (
    get_flattened_position_ids_extrapolate
)

hidden_size instance-attribute

hidden_size = config.llm_config.hidden_size

language_model instance-attribute

language_model = language_model

latent_channel instance-attribute

latent_channel = config.vae_config.z_channels

latent_downsample instance-attribute

latent_downsample = (
    config.vae_config.downsample * config.latent_patch_size
)

latent_patch_size instance-attribute

latent_patch_size = config.latent_patch_size

latent_pos_embed instance-attribute

latent_pos_embed = PositionEmbedding(
    self.max_latent_size, self.hidden_size
)

llm2vae instance-attribute

llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)

max_latent_size instance-attribute

max_latent_size = config.max_latent_size

num_heads instance-attribute

num_heads = config.llm_config.num_attention_heads

parallel_config instance-attribute

parallel_config = parallel_config

patch_latent_dim instance-attribute

patch_latent_dim = (
    self.latent_patch_size**2 * self.latent_channel
)

time_embedder instance-attribute

time_embedder = TimestepEmbedder(self.hidden_size)

timestep_shift instance-attribute

timestep_shift = config.timestep_shift

use_moe instance-attribute

use_moe = 'Mo' in config.llm_config.layer_module

vae2llm instance-attribute

vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)

vit_hidden_size instance-attribute

vit_hidden_size = config.vit_config.hidden_size

vit_max_num_patch_per_side instance-attribute

vit_max_num_patch_per_side = (
    config.vit_max_num_patch_per_side
)

vit_model instance-attribute

vit_model = vit_model

vit_patch_size instance-attribute

vit_patch_size = config.vit_config.patch_size

vit_pos_embed instance-attribute

vit_pos_embed = PositionEmbedding(
    self.vit_max_num_patch_per_side, self.hidden_size
)

forward

forward(
    x_t: Tensor,
    timestep: LongTensor,
    packed_vae_token_indexes: LongTensor,
    packed_vae_position_ids: LongTensor,
    packed_text_ids: LongTensor,
    packed_text_indexes: LongTensor,
    packed_position_ids: LongTensor,
    packed_seqlens: IntTensor,
    past_key_values: NaiveCache,
    cfg_renorm_min: float = 0.0,
    cfg_renorm_type: str = "global",
    cfg_text_scale: float = 1.0,
    cfg_img_scale: float = 1.0,
    cfg_branches: dict | None = None,
)

forward_cache_update_text

forward_cache_update_text(
    past_key_values: NaiveCache,
    packed_text_ids: IntTensor,
    packed_text_position_ids: LongTensor,
    text_token_lens: LongTensor,
)

forward_cache_update_vae

forward_cache_update_vae(
    vae_model,
    past_key_values: NaiveCache,
    padded_images: Tensor,
    patchified_vae_latent_shapes: list,
    packed_vae_position_ids: LongTensor,
    packed_timesteps: Tensor,
    packed_vae_token_indexes: LongTensor,
    packed_text_ids: LongTensor,
    packed_text_indexes: LongTensor,
    packed_position_ids: LongTensor,
    packed_seqlens: IntTensor,
)

forward_cache_update_vit

forward_cache_update_vit(
    past_key_values: NaiveCache,
    packed_text_ids: LongTensor,
    packed_text_indexes: LongTensor,
    packed_vit_tokens: Tensor,
    packed_vit_token_indexes: LongTensor,
    packed_vit_position_ids: LongTensor,
    vit_token_seqlens: IntTensor,
    packed_position_ids: LongTensor,
    packed_seqlens: IntTensor,
)

forward_single_branch

forward_single_branch(
    x_t: Tensor,
    timestep: LongTensor,
    packed_vae_token_indexes: LongTensor,
    packed_vae_position_ids: LongTensor,
    packed_text_ids: LongTensor,
    packed_text_indexes: LongTensor,
    packed_position_ids: LongTensor,
    packed_seqlens: IntTensor,
    past_key_values: NaiveCache,
) -> Tensor

Run a single-branch forward pass (no CFG batching).

Used by CFG parallel mode, where each rank computes one branch. Returns the velocity v_t for the given branch. Supports Ulysses / Ring sequence parallelism (SP) when parallel_config.sequence_parallel_size > 1.

generate_image

generate_image(
    packed_text_ids: LongTensor,
    packed_text_indexes: LongTensor,
    packed_init_noises: Tensor,
    packed_vae_position_ids: LongTensor,
    packed_vae_token_indexes: LongTensor,
    packed_seqlens: IntTensor,
    packed_position_ids: LongTensor,
    past_key_values: NaiveCache,
    num_timesteps: int = 24,
    timestep_shift: float = 1.0,
    cfg_renorm_min: float = 0.0,
    cfg_renorm_type: str = "global",
    cfg_interval: tuple[float, float] = [0, 1],
    cfg_text_scale: float = 1.0,
    cfg_text_packed_position_ids: LongTensor | None = None,
    cfg_text_past_key_values: NaiveCache | None = None,
    cfg_img_scale: float = 1.0,
    cfg_img_packed_position_ids: LongTensor | None = None,
    cfg_img_past_key_values: NaiveCache | None = None,
    return_trajectory_latents: bool = False,
    scheduler: object | None = None,
    scheduler_kwargs: dict | None = None,
    frame_condition_token_indexes: LongTensor | None = None,
)

generate_text

generate_text(
    past_key_values: NaiveCache,
    packed_start_tokens: LongTensor,
    packed_query_position_ids: LongTensor,
    max_length: int,
    do_sample: bool = False,
    temperature: float = 1.0,
    end_token_id: int | None = None,
)

Autoregressive text generation (ported from original BAGEL).

Decodes tokens one at a time, appending to past_key_values until max_length is reached or end_token_id is generated.

prepare_input

prepare_input(
    curr_kvlens, curr_rope, image_sizes, new_token_ids=None
)

prepare_prompts

prepare_prompts(
    curr_kvlens,
    curr_rope,
    prompts,
    tokenizer,
    new_token_ids,
)

prepare_start_tokens

prepare_start_tokens(curr_kvlens, curr_rope, new_token_ids)

Prepare start tokens for autoregressive text generation.

Ported from the original BAGEL Bagel.prepare_start_tokens.

prepare_vae_images

prepare_vae_images(
    curr_kvlens,
    curr_rope,
    images,
    transforms,
    new_token_ids,
    timestep=0,
)

prepare_vae_latent

prepare_vae_latent(
    curr_kvlens, curr_rope, image_sizes, new_token_ids
)

prepare_vae_latent_cfg

prepare_vae_latent_cfg(curr_kvlens, curr_rope, image_sizes)

prepare_vit_images

prepare_vit_images(
    curr_kvlens,
    curr_rope,
    images,
    transforms,
    new_token_ids,
)

BagelMLP

Bases: Module

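Gated feed-forward network (FFN) built from tensor-parallel linear layers.

gate_proj + up_proj are fused into a single MergedColumnParallelLinear (gate_up_proj), and down_proj uses RowParallelLinear; the activation is SiluAndMul. As the attributes and the forward(x) signature below show, this module takes only the hidden states and holds a single set of weights. The text and generation (vae) branches are kept as two separate BagelMLP instances on Qwen2MoTDecoderLayer (mlp and mlp_moe_gen) rather than being routed by text_indices / vae_indices through a gen_exp sub-layer inside this module.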

act_fn instance-attribute

act_fn = SiluAndMul()

down_proj instance-attribute

down_proj = RowParallelLinear(
    input_size=intermediate_size,
    output_size=hidden_size,
    bias=False,
    input_is_parallel=True,
    quant_config=quant_config,
    prefix=f"{prefix}.down_proj",
)

gate_up_proj instance-attribute

gate_up_proj = MergedColumnParallelLinear(
    input_size=hidden_size,
    output_sizes=[intermediate_size, intermediate_size],
    bias=False,
    gather_output=False,
    quant_config=quant_config,
    prefix=f"{prefix}.gate_up_proj",
)

intermediate_size instance-attribute

intermediate_size = intermediate_size

forward

forward(x: Tensor) -> Tensor

BagelRotaryEmbedding

Bases: Module

Standalone rotary embedding that generates cos/sin from position ids.

Replaces HuggingFace's Qwen2RotaryEmbedding while preserving full rope_scaling support. When config.rope_scaling is set (e.g. linear, dynamic-NTK, YaRN, …), we delegate the inv_freq / attention_scaling computation to HF's ROPE_INIT_FUNCTIONS so that the frequency basis and scaling factor are identical to the original checkpoint.

For Qwen2.5-VL-style multimodal RoPE (rope_scaling.rope_type == "mrope") the inv_freq basis is the standard default-rope one; the difference is that position ids are 3-D (t, h, w) per token and the mrope_section describes how the head dimension is split across axes. This module accepts either 2-D scalar position ids (B, S) or 3-D multimodal position ids (B, 3, S) and dispatches accordingly so the same module works for both BAGEL (1-D rope) and Lance (Qwen2.5-VL mrope). This module has no learnable parameters.

attention_scaling instance-attribute

attention_scaling = 1.0

forward

forward(
    x: Tensor, position_ids: Tensor
) -> tuple[Tensor, Tensor]

Generate cos/sin embeddings for given position ids.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor (only used for dtype inference). | required |
| position_ids | Tensor | Either 2-D scalar (batch_size, seq_len) for plain 1-D RoPE, or 3-D multimodal (batch_size, 3, seq_len) for Qwen2.5-VL-style mRoPE. The latter is auto-detected from position_ids.ndim. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[Tensor, Tensor] | cos, sin: Rotary embeddings, each of shape (batch_size, seq_len, dim). |

BaseNavitOutputWithPast dataclass

Bases: ModelOutput

packed_query_sequence class-attribute instance-attribute

packed_query_sequence: FloatTensor = None

past_key_values class-attribute instance-attribute

past_key_values: NaiveCache | None = None

MLPconnector

Bases: Module

act instance-attribute

act = nn.GELU()

fc1 instance-attribute

fc1 = ColumnParallelLinear(
    input_dim,
    output_dim,
    bias=True,
    gather_output=False,
    quant_config=quant_config,
    prefix=f"{prefix}.fc1",
)

fc2 instance-attribute

fc2 = RowParallelLinear(
    output_dim,
    output_dim,
    bias=True,
    input_is_parallel=True,
    quant_config=quant_config,
    prefix=f"{prefix}.fc2",
)

forward

forward(x)

NaiveCache

key_cache instance-attribute

key_cache = {k: None for k in (range(num_layers))}

num_layers property

num_layers

seq_lens property

seq_lens

value_cache instance-attribute

value_cache = {k: None for k in (range(num_layers))}

PackedAttentionMoT

Bases: Module

Packed attention with Mixture-of-Tokens routing for understanding/generation.

Uses MoTQKVParallelLinear and MoTRowParallelLinear for tensor parallelism. Text and vae weights are held within the same MoT layer (text on self, vae on self.gen_exp). Token routing is driven by text_indices / vae_indices.

attn_causal instance-attribute

attn_causal = DiffusionAttention(
    num_heads=self.total_num_heads,
    head_size=self.head_dim,
    softmax_scale=1.0 / self.head_dim**0.5,
    causal=True,
    num_kv_heads=self.total_num_kv_heads,
)

attn_noncausal instance-attribute

attn_noncausal = DiffusionAttention(
    num_heads=self.total_num_heads,
    head_size=self.head_dim,
    softmax_scale=1.0 / self.head_dim**0.5,
    causal=False,
    num_kv_heads=self.total_num_kv_heads,
)

head_dim instance-attribute

head_dim = self.hidden_size // self.total_num_heads

hidden_size instance-attribute

hidden_size = config.hidden_size

k_norm instance-attribute

k_norm = MoTRMSNorm(
    self.head_dim, head_norm=True, eps=config.rms_norm_eps
)

kv_size instance-attribute

kv_size = self.num_kv_heads * self.head_dim

layer_idx instance-attribute

layer_idx = layer_idx

num_heads instance-attribute

num_heads = self.total_num_heads // tp_size

num_kv_heads instance-attribute

num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

o_proj instance-attribute

o_proj = MoTRowParallelLinear(
    input_size=self.total_num_heads * self.head_dim,
    output_size=self.hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj",
)

parallel_config instance-attribute

parallel_config = parallel_config

q_norm instance-attribute

q_norm = MoTRMSNorm(
    self.head_dim, head_norm=True, eps=config.rms_norm_eps
)

q_size instance-attribute

q_size = self.num_heads * self.head_dim

qkv_proj instance-attribute

qkv_proj = MoTQKVParallelLinear(
    self.hidden_size,
    self.head_dim,
    self.total_num_heads,
    self.total_num_kv_heads,
    bias=True,
    vae_bias=True,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

rotary_op instance-attribute

rotary_op = RotaryEmbedding(is_neox_style=True)

total_num_heads instance-attribute

total_num_heads = config.num_attention_heads

total_num_kv_heads instance-attribute

total_num_kv_heads = config.num_key_value_heads

forward

forward(
    packed_query_sequence: Tensor,
    query_lens: Tensor,
    packed_query_position_embeddings: Tensor,
    past_key_values: NaiveCache | None = None,
    update_past_key_values=True,
    is_causal=True,
    mode="und",
    packed_vae_token_indexes=None,
    packed_text_indexes=None,
)

PositionEmbedding

Bases: Module

hidden_size instance-attribute

hidden_size = hidden_size

max_num_patch_per_side instance-attribute

max_num_patch_per_side = max_num_patch_per_side

pos_embed instance-attribute

pos_embed = nn.Parameter(
    torch.zeros(max_num_patch_per_side**2, hidden_size),
    requires_grad=False,
)

forward

forward(position_ids)

Qwen2MoTConfig

Bases: Qwen2Config

Configuration for Qwen2MoT (Mixture of Tokens) model.

This is fundamentally different from Qwen2, hence the distinct name.

keys_to_ignore_at_inference class-attribute instance-attribute

keys_to_ignore_at_inference = ['past_key_values']

layer_module instance-attribute

layer_module = layer_module

model_type class-attribute instance-attribute

model_type = 'qwen2_mot'

qk_norm instance-attribute

qk_norm = qk_norm

Qwen2MoTDecoderLayer

Bases: Module

hidden_size instance-attribute

hidden_size = config.hidden_size

input_layernorm instance-attribute

input_layernorm = MoTRMSNorm(
    config.hidden_size, eps=config.rms_norm_eps
)

layer_idx instance-attribute

layer_idx = layer_idx

mlp instance-attribute

mlp = BagelMLP(
    config.hidden_size,
    config.intermediate_size,
    config.hidden_act,
    quant_config=quant_config,
    prefix=f"{prefix}.mlp",
)

mlp_moe_gen instance-attribute

mlp_moe_gen = BagelMLP(
    config.hidden_size,
    config.intermediate_size,
    config.hidden_act,
    quant_config=quant_config,
    prefix=f"{prefix}.mlp_moe_gen",
)

post_attention_layernorm instance-attribute

post_attention_layernorm = MoTRMSNorm(
    config.hidden_size, eps=config.rms_norm_eps
)

self_attn instance-attribute

self_attn = attn_module(
    config,
    layer_idx,
    parallel_config=parallel_config,
    quant_config=quant_config,
    prefix=f"{prefix}.self_attn",
)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor | None = None,
    packed_query_sequence: Tensor | None = None,
    query_lens: Tensor = None,
    packed_query_position_embeddings: Tensor = None,
    past_key_values: NaiveCache | None = None,
    update_past_key_values=True,
    is_causal=True,
    mode="und",
    packed_vae_token_indexes=None,
    packed_text_indexes=None,
) -> BaseNavitOutputWithPast

Qwen2MoTForCausalLM

Bases: Qwen2PreTrainedModel

lm_head instance-attribute

lm_head = nn.Linear(
    config.hidden_size, config.vocab_size, bias=False
)

model instance-attribute

model = Qwen2MoTModel(
    config,
    parallel_config=parallel_config,
    quant_config=quant_config,
    prefix=f"{prefix}.model",
)

vocab_size instance-attribute

vocab_size = config.vocab_size

forward

forward(
    packed_query_sequence: Tensor | None = None,
    query_lens: Tensor | None = None,
    packed_query_position_ids: Tensor | None = None,
    past_key_values: NaiveCache | None = None,
    update_past_key_values=True,
    is_causal=True,
    mode="und",
    packed_vae_token_indexes=None,
    packed_text_indexes=None,
    packed_text_ids: Tensor | None = None,
    return_embeddings_only: bool = False,
) -> BaseNavitOutputWithPast

get_decoder

get_decoder()

get_input_embeddings

get_input_embeddings()

get_output_embeddings

get_output_embeddings()

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights for MoT parallel layers.

Stacked parameter remapping (checkpoint name → model parameter):

- q/k/v_proj → qkv_proj (text, shard q/k/v)
- q/k/v_proj_moe_gen → qkv_proj.gen_exp (gen, shard q/k/v)

Direct remapping (no shard dimension):

- o_proj_moe_gen → o_proj.gen_exp
- {norm}_moe_gen.weight → {norm}.gen_weight (all MoTRMSNorm layers)

Text norm weights (input_layernorm.weight, q_norm.weight, etc.) and other names (embed_tokens, lm_head) pass through unchanged.

set_decoder

set_decoder(decoder)

set_input_embeddings

set_input_embeddings(value)

set_output_embeddings

set_output_embeddings(new_embeddings)

Qwen2MoTModel

Bases: Qwen2PreTrainedModel

embed_tokens instance-attribute

embed_tokens = VocabParallelEmbedding(
    config.vocab_size, config.hidden_size
)

layers instance-attribute

layers = nn.ModuleList(
    [
        (
            Qwen2MoTDecoderLayer(
                config,
                layer_idx,
                attn_module=PackedAttentionMoT,
                parallel_config=parallel_config,
                quant_config=quant_config,
                prefix=f"{prefix}.layers.{layer_idx}",
            )
        )
        for layer_idx in (range(config.num_hidden_layers))
    ]
)

norm instance-attribute

norm = MoTRMSNorm(
    config.hidden_size, eps=config.rms_norm_eps
)

padding_idx instance-attribute

padding_idx = config.pad_token_id

rotary_emb instance-attribute

rotary_emb = BagelRotaryEmbedding(config=config)

use_moe instance-attribute

use_moe = 'Mo' in config.layer_module

vocab_size instance-attribute

vocab_size = config.vocab_size

forward

forward(
    packed_query_sequence: Tensor | None = None,
    query_lens: Tensor | None = None,
    packed_query_position_ids: Tensor | None = None,
    past_key_values: NaiveCache | None = None,
    update_past_key_values=True,
    is_causal=True,
    mode="und",
    packed_vae_token_indexes=None,
    packed_text_indexes=None,
    packed_text_ids: Tensor | None = None,
    return_embeddings_only: bool = False,
) -> BaseNavitOutputWithPast

TimestepEmbedder

Bases: Module

Embeds scalar timesteps into vector representations.

frequency_embedding_size instance-attribute

frequency_embedding_size = frequency_embedding_size

mlp instance-attribute

mlp = nn.Sequential(
    nn.Linear(
        frequency_embedding_size, hidden_size, bias=True
    ),
    nn.SiLU(),
    nn.Linear(hidden_size, hidden_size, bias=True),
)

forward

forward(t)

get_1d_sincos_pos_embed_from_grid

get_1d_sincos_pos_embed_from_grid(embed_dim, pos)

embed_dim: output dimension for each position
pos: a list of positions to be encoded, size (M,)
out: (M, D)

get_2d_sincos_pos_embed

get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0
)

get_2d_sincos_pos_embed_from_grid

get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

get_flattened_position_ids_extrapolate

get_flattened_position_ids_extrapolate(
    img_h, img_w, patch_size, max_num_patches_per_side
)

patchify

patchify(imgs, p)

imgs: (N, 3, H, W) or (3, H, W)
x: (N, L, patch_size**2 * 3) or (L, patch_size**2 * 3)