Skip to content

vllm_omni.diffusion.models.bagel.bagel_transformer

Bagel

Bases: Module

base_model_prefix class-attribute instance-attribute

base_model_prefix = 'bagel'

config instance-attribute

config = config

config_class class-attribute instance-attribute

config_class = BagelConfig

connector instance-attribute

connector = MLPconnector(
    vit_hidden_size,
    hidden_size,
    connector_act,
    quant_config=quant_config,
    prefix=f"{prefix}.connector",
)

get_flattened_position_ids instance-attribute

get_flattened_position_ids = (
    get_flattened_position_ids_extrapolate
)

hidden_size instance-attribute

hidden_size = hidden_size

language_model instance-attribute

language_model = language_model

latent_channel instance-attribute

latent_channel = z_channels

latent_downsample instance-attribute

latent_downsample = downsample * latent_patch_size

latent_patch_size instance-attribute

latent_patch_size = latent_patch_size

latent_pos_embed instance-attribute

latent_pos_embed = PositionEmbedding(
    max_latent_size, hidden_size
)

llm2vae instance-attribute

llm2vae = Linear(hidden_size, patch_latent_dim)

max_latent_size instance-attribute

max_latent_size = max_latent_size

num_heads instance-attribute

num_heads = num_attention_heads

parallel_config instance-attribute

parallel_config = parallel_config

patch_latent_dim instance-attribute

patch_latent_dim = latent_patch_size ** 2 * latent_channel

time_embedder instance-attribute

time_embedder = TimestepEmbedder(hidden_size)

timestep_shift instance-attribute

timestep_shift = timestep_shift

use_moe instance-attribute

use_moe = 'Mo' in layer_module

vae2llm instance-attribute

vae2llm = Linear(patch_latent_dim, hidden_size)

vit_hidden_size instance-attribute

vit_hidden_size = hidden_size

vit_max_num_patch_per_side instance-attribute

vit_max_num_patch_per_side = vit_max_num_patch_per_side

vit_model instance-attribute

vit_model = vit_model

vit_patch_size instance-attribute

vit_patch_size = patch_size

vit_pos_embed instance-attribute

vit_pos_embed = PositionEmbedding(
    vit_max_num_patch_per_side, hidden_size
)

forward

forward(
    x_t: Tensor,
    timestep: LongTensor,
    packed_vae_token_indexes: LongTensor,
    packed_vae_position_ids: LongTensor,
    packed_text_ids: LongTensor,
    packed_text_indexes: LongTensor,
    packed_position_ids: LongTensor,
    packed_seqlens: IntTensor,
    past_key_values: NaiveCache,
    cfg_renorm_min: float = 0.0,
    cfg_renorm_type: str = "global",
    cfg_text_scale: float = 1.0,
    cfg_img_scale: float = 1.0,
    cfg_branches: dict | None = None,
)

forward_cache_update_text

forward_cache_update_text(
    past_key_values: NaiveCache,
    packed_text_ids: IntTensor,
    packed_text_position_ids: LongTensor,
    text_token_lens: LongTensor,
)

forward_cache_update_vae

forward_cache_update_vae(
    vae_model,
    past_key_values: NaiveCache,
    padded_images: Tensor,
    patchified_vae_latent_shapes: list,
    packed_vae_position_ids: LongTensor,
    packed_timesteps: Tensor,
    packed_vae_token_indexes: LongTensor,
    packed_text_ids: LongTensor,
    packed_text_indexes: LongTensor,
    packed_position_ids: LongTensor,
    packed_seqlens: IntTensor,
)

forward_cache_update_vit

forward_cache_update_vit(
    past_key_values: NaiveCache,
    packed_text_ids: LongTensor,
    packed_text_indexes: LongTensor,
    packed_vit_tokens: Tensor,
    packed_vit_token_indexes: LongTensor,
    packed_vit_position_ids: LongTensor,
    vit_token_seqlens: IntTensor,
    packed_position_ids: LongTensor,
    packed_seqlens: IntTensor,
)

forward_single_branch

forward_single_branch(
    x_t: Tensor,
    timestep: LongTensor,
    packed_vae_token_indexes: LongTensor,
    packed_vae_position_ids: LongTensor,
    packed_text_ids: LongTensor,
    packed_text_indexes: LongTensor,
    packed_position_ids: LongTensor,
    packed_seqlens: IntTensor,
    past_key_values: NaiveCache,
) -> Tensor

Run a single-branch forward pass (no CFG batching).

Used by CFG-parallel mode, where each rank computes one branch. Returns the velocity v_t for the given branch. Supports Ulysses / Ring sequence parallelism (SP) when parallel_config.sequence_parallel_size > 1.

generate_image

generate_image(
    packed_text_ids: LongTensor,
    packed_text_indexes: LongTensor,
    packed_init_noises: Tensor,
    packed_vae_position_ids: LongTensor,
    packed_vae_token_indexes: LongTensor,
    packed_seqlens: IntTensor,
    packed_position_ids: LongTensor,
    past_key_values: NaiveCache,
    num_timesteps: int = 24,
    timestep_shift: float = 1.0,
    cfg_renorm_min: float = 0.0,
    cfg_renorm_type: str = "global",
    cfg_interval: tuple[float, float] = [0, 1],
    cfg_text_scale: float = 1.0,
    cfg_text_packed_position_ids: LongTensor | None = None,
    cfg_text_past_key_values: NaiveCache | None = None,
    cfg_img_scale: float = 1.0,
    cfg_img_packed_position_ids: LongTensor | None = None,
    cfg_img_past_key_values: NaiveCache | None = None,
    return_trajectory_latents: bool = False,
    scheduler: object | None = None,
    scheduler_kwargs: dict | None = None,
    frame_condition_token_indexes: LongTensor | None = None,
)

generate_text

generate_text(
    past_key_values: NaiveCache,
    packed_start_tokens: LongTensor,
    packed_query_position_ids: LongTensor,
    max_length: int,
    do_sample: bool = False,
    temperature: float = 1.0,
    end_token_id: int | None = None,
)

Autoregressive text generation (ported from the original BAGEL).

Decodes tokens one at a time, appending to past_key_values until max_length is reached or end_token_id is generated.

prepare_input

prepare_input(
    curr_kvlens, curr_rope, image_sizes, new_token_ids=None
)

prepare_prompts

prepare_prompts(
    curr_kvlens,
    curr_rope,
    prompts,
    tokenizer,
    new_token_ids,
)

prepare_start_tokens

prepare_start_tokens(curr_kvlens, curr_rope, new_token_ids)

Prepare start tokens for autoregressive text generation.

Ported from the original BAGEL Bagel.prepare_start_tokens.

prepare_vae_images

prepare_vae_images(
    curr_kvlens,
    curr_rope,
    images,
    transforms,
    new_token_ids,
    timestep=0,
)

prepare_vae_latent

prepare_vae_latent(
    curr_kvlens, curr_rope, image_sizes, new_token_ids
)

prepare_vae_latent_cfg

prepare_vae_latent_cfg(curr_kvlens, curr_rope, image_sizes)

prepare_vit_images

prepare_vit_images(
    curr_kvlens,
    curr_rope,
    images,
    transforms,
    new_token_ids,
)

BagelMLP

Bases: Module

act_fn instance-attribute

act_fn = SiLU()

down_proj instance-attribute

down_proj = RowParallelLinear(
    intermediate_size,
    hidden_size,
    input_is_parallel=True,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.down_proj",
)

gate_up_proj instance-attribute

gate_up_proj = MergedColumnParallelLinear(
    hidden_size,
    [intermediate_size, intermediate_size],
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.gate_up_proj",
)

forward

forward(x)

BagelRotaryEmbedding

Bases: Module

Standalone rotary embedding that generates cos/sin from position ids.

Replaces HuggingFace's Qwen2RotaryEmbedding while preserving full rope_scaling support. When config.rope_scaling is set (e.g. linear, dynamic-NTK, YaRN, …), we delegate the inv_freq / attention_scaling computation to HF's ROPE_INIT_FUNCTIONS so that the frequency basis and scaling factor are identical to the original checkpoint.

For Qwen2.5-VL-style multimodal RoPE (rope_scaling.rope_type == "mrope") the inv_freq basis is the standard default-rope one; the difference is that position ids are 3-D (t, h, w) per token and the mrope_section describes how the head dimension is split across axes. This module accepts either 2-D scalar position ids (B, S) or 3-D multimodal position ids (B, 3, S) and dispatches accordingly so the same module works for both BAGEL (1-D rope) and Lance (Qwen2.5-VL mrope). This module has no learnable parameters.

attention_scaling instance-attribute

attention_scaling = 1.0

forward

forward(
    x: Tensor, position_ids: Tensor
) -> tuple[Tensor, Tensor]

Generate cos/sin embeddings for given position ids.

Parameters:

- x (Tensor, required): Input tensor (only used for dtype inference).
- position_ids (Tensor, required): Either 2-D scalar position ids of shape (batch_size, seq_len) for plain 1-D RoPE, or 3-D multimodal position ids of shape (batch_size, 3, seq_len) for Qwen2.5-VL-style mRoPE. The layout is auto-detected from position_ids.ndim.

Returns:

- tuple[Tensor, Tensor]: (cos, sin) rotary embeddings, each of shape (batch_size, seq_len, dim).

BaseNavitOutputWithPast dataclass

Bases: ModelOutput

packed_query_sequence class-attribute instance-attribute

packed_query_sequence: FloatTensor = None

past_key_values class-attribute instance-attribute

past_key_values: NaiveCache | None = None

MLPconnector

Bases: Module

act instance-attribute

act = GELU()

fc1 instance-attribute

fc1 = ColumnParallelLinear(
    input_dim,
    output_dim,
    bias=True,
    gather_output=False,
    quant_config=quant_config,
    prefix=f"{prefix}.fc1",
)

fc2 instance-attribute

fc2 = RowParallelLinear(
    output_dim,
    output_dim,
    bias=True,
    input_is_parallel=True,
    quant_config=quant_config,
    prefix=f"{prefix}.fc2",
)

forward

forward(x)

NaiveCache

key_cache instance-attribute

key_cache = {k: None for k in (range(num_layers))}

num_layers property

num_layers

seq_lens property

seq_lens

value_cache instance-attribute

value_cache = {k: None for k in (range(num_layers))}

PackedAttentionMoT

Bases: Module

Packed attention with Mixture-of-Tokens routing for understanding/generation.

Uses vLLM's QKVParallelLinear and RowParallelLinear for tensor parallelism support, following the same pattern as vLLM's Qwen2Attention.

The q/k/v projections are stacked into two QKVParallelLinear layers:

- qkv_proj: stacks q_proj + k_proj + v_proj (understanding and generation-text tokens)
- qkv_proj_moe_gen: stacks q_proj_moe_gen + k_proj_moe_gen + v_proj_moe_gen (generation VAE tokens)

attn_causal instance-attribute

attn_causal = Attention(
    num_heads=total_num_heads,
    head_size=head_dim,
    softmax_scale=1.0 / head_dim**0.5,
    causal=True,
    num_kv_heads=total_num_kv_heads,
)

attn_noncausal instance-attribute

attn_noncausal = Attention(
    num_heads=total_num_heads,
    head_size=head_dim,
    softmax_scale=1.0 / head_dim**0.5,
    causal=False,
    num_kv_heads=total_num_kv_heads,
)

head_dim instance-attribute

head_dim = hidden_size // total_num_heads

hidden_size instance-attribute

hidden_size = hidden_size

k_norm instance-attribute

k_norm = RMSNorm(head_dim, eps=rms_norm_eps)

k_norm_moe_gen instance-attribute

k_norm_moe_gen = RMSNorm(head_dim, eps=rms_norm_eps)

kv_size instance-attribute

kv_size = num_kv_heads * head_dim

layer_idx instance-attribute

layer_idx = layer_idx

num_heads instance-attribute

num_heads = total_num_heads // tp_size

num_kv_heads instance-attribute

num_kv_heads = max(1, total_num_kv_heads // tp_size)

o_proj instance-attribute

o_proj = RowParallelLinear(
    total_num_heads * head_dim,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj",
)

o_proj_moe_gen instance-attribute

o_proj_moe_gen = RowParallelLinear(
    total_num_heads * head_dim,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj_moe_gen",
)

parallel_config instance-attribute

parallel_config = parallel_config

q_norm instance-attribute

q_norm = RMSNorm(head_dim, eps=rms_norm_eps)

q_norm_moe_gen instance-attribute

q_norm_moe_gen = RMSNorm(head_dim, eps=rms_norm_eps)

q_size instance-attribute

q_size = num_heads * head_dim

qkv_proj instance-attribute

qkv_proj = QKVParallelLinear(
    hidden_size,
    head_dim,
    total_num_heads,
    total_num_kv_heads,
    bias=True,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

qkv_proj_moe_gen instance-attribute

qkv_proj_moe_gen = QKVParallelLinear(
    hidden_size,
    head_dim,
    total_num_heads,
    total_num_kv_heads,
    bias=True,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj_moe_gen",
)

rotary_op instance-attribute

rotary_op = RotaryEmbedding(is_neox_style=True)

total_num_heads instance-attribute

total_num_heads = num_attention_heads

total_num_kv_heads instance-attribute

total_num_kv_heads = num_key_value_heads

forward

forward(
    packed_query_sequence: Tensor,
    query_lens: Tensor,
    packed_query_position_embeddings: Tensor,
    past_key_values: NaiveCache | None = None,
    update_past_key_values=True,
    is_causal=True,
    mode="und",
    packed_vae_token_indexes=None,
    packed_text_indexes=None,
)

PositionEmbedding

Bases: Module

hidden_size instance-attribute

hidden_size = hidden_size

max_num_patch_per_side instance-attribute

max_num_patch_per_side = max_num_patch_per_side

pos_embed instance-attribute

pos_embed = Parameter(
    zeros(max_num_patch_per_side**2, hidden_size),
    requires_grad=False,
)

forward

forward(position_ids)

Qwen2MoTConfig

Bases: Qwen2Config

Configuration for Qwen2MoT (Mixture of Tokens) model.

This is fundamentally different from Qwen2, hence the distinct name.

keys_to_ignore_at_inference class-attribute instance-attribute

keys_to_ignore_at_inference = ['past_key_values']

layer_module instance-attribute

layer_module = layer_module

model_type class-attribute instance-attribute

model_type = 'qwen2_mot'

qk_norm instance-attribute

qk_norm = qk_norm

Qwen2MoTDecoderLayer

Bases: Module

hidden_size instance-attribute

hidden_size = hidden_size

input_layernorm instance-attribute

input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

input_layernorm_moe_gen instance-attribute

input_layernorm_moe_gen = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

mlp instance-attribute

mlp = BagelMLP(
    hidden_size,
    intermediate_size,
    hidden_act,
    quant_config=quant_config,
    prefix=f"{prefix}.mlp",
)

mlp_moe_gen instance-attribute

mlp_moe_gen = BagelMLP(
    hidden_size,
    intermediate_size,
    hidden_act,
    quant_config=quant_config,
    prefix=f"{prefix}.mlp_moe_gen",
)

post_attention_layernorm instance-attribute

post_attention_layernorm = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

post_attention_layernorm_moe_gen instance-attribute

post_attention_layernorm_moe_gen = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

self_attn instance-attribute

self_attn = attn_module(
    config,
    layer_idx,
    parallel_config=parallel_config,
    quant_config=quant_config,
    prefix=f"{prefix}.self_attn",
)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor | None = None,
    packed_query_sequence: Tensor | None = None,
    query_lens: Tensor = None,
    packed_query_position_embeddings: Tensor = None,
    past_key_values: NaiveCache | None = None,
    update_past_key_values=True,
    is_causal=True,
    mode="und",
    packed_vae_token_indexes=None,
    packed_text_indexes=None,
) -> BaseNavitOutputWithPast

Qwen2MoTForCausalLM

Bases: Qwen2PreTrainedModel

lm_head instance-attribute

lm_head = Linear(hidden_size, vocab_size, bias=False)

model instance-attribute

model = Qwen2MoTModel(
    config,
    parallel_config=parallel_config,
    quant_config=quant_config,
    prefix=f"{prefix}.model",
)

vocab_size instance-attribute

vocab_size = vocab_size

forward

forward(
    packed_query_sequence: Tensor | None = None,
    query_lens: Tensor | None = None,
    packed_query_position_ids: Tensor | None = None,
    past_key_values: NaiveCache | None = None,
    update_past_key_values=True,
    is_causal=True,
    mode="und",
    packed_vae_token_indexes=None,
    packed_text_indexes=None,
    packed_text_ids: Tensor | None = None,
    return_embeddings_only: bool = False,
) -> BaseNavitOutputWithPast

get_decoder

get_decoder()

get_input_embeddings

get_input_embeddings()

get_output_embeddings

get_output_embeddings()

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights for vLLM parallel layers.

Handles stacked parameter remapping for QKVParallelLinear:

- q_proj, k_proj, v_proj -> qkv_proj (shard ids: q, k, v)
- q_proj_moe_gen, k_proj_moe_gen, v_proj_moe_gen -> qkv_proj_moe_gen

Other parallel layers (gate_proj, up_proj, down_proj, embed_tokens, etc.) keep HF checkpoint names and use weight_loader for TP sharding.

set_decoder

set_decoder(decoder)

set_input_embeddings

set_input_embeddings(value)

set_output_embeddings

set_output_embeddings(new_embeddings)

Qwen2MoTModel

Bases: Qwen2PreTrainedModel

embed_tokens instance-attribute

embed_tokens = VocabParallelEmbedding(
    vocab_size, hidden_size
)

layers instance-attribute

layers = ModuleList(
    [
        (
            Qwen2MoTDecoderLayer(
                config,
                layer_idx,
                attn_module=PackedAttentionMoT,
                parallel_config=parallel_config,
                quant_config=quant_config,
                prefix=f"{prefix}.layers.{layer_idx}",
            )
        )
        for layer_idx in (range(num_hidden_layers))
    ]
)

norm instance-attribute

norm = RMSNorm(hidden_size, eps=rms_norm_eps)

norm_moe_gen instance-attribute

norm_moe_gen = RMSNorm(hidden_size, eps=rms_norm_eps)

padding_idx instance-attribute

padding_idx = pad_token_id

rotary_emb instance-attribute

rotary_emb = BagelRotaryEmbedding(config=config)

use_moe instance-attribute

use_moe = 'Mo' in layer_module

vocab_size instance-attribute

vocab_size = vocab_size

forward

forward(
    packed_query_sequence: Tensor | None = None,
    query_lens: Tensor | None = None,
    packed_query_position_ids: Tensor | None = None,
    past_key_values: NaiveCache | None = None,
    update_past_key_values=True,
    is_causal=True,
    mode="und",
    packed_vae_token_indexes=None,
    packed_text_indexes=None,
    packed_text_ids: Tensor | None = None,
    return_embeddings_only: bool = False,
) -> BaseNavitOutputWithPast

TimestepEmbedder

Bases: Module

Embeds scalar timesteps into vector representations.

frequency_embedding_size instance-attribute

frequency_embedding_size = frequency_embedding_size

mlp instance-attribute

mlp = Sequential(
    Linear(
        frequency_embedding_size, hidden_size, bias=True
    ),
    SiLU(),
    Linear(hidden_size, hidden_size, bias=True),
)

forward

forward(t)

get_1d_sincos_pos_embed_from_grid

get_1d_sincos_pos_embed_from_grid(embed_dim, pos)

Args:

- embed_dim: output dimension for each position.
- pos: a list of positions to be encoded, of size (M,).

Returns:

- out: array of shape (M, D), with D = embed_dim.

get_2d_sincos_pos_embed

get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0
)

get_2d_sincos_pos_embed_from_grid

get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

get_flattened_position_ids_extrapolate

get_flattened_position_ids_extrapolate(
    img_h, img_w, patch_size, max_num_patches_per_side
)

patchify

patchify(imgs, p)

- imgs: (N, 3, H, W) or (3, H, W)
- x (output): (N, L, patch_size**2 * 3) or (L, patch_size**2 * 3)