
vllm_omni.diffusion.models.lance.lance_transformer

Lance transformer pieces.

The Lance LLM is BAGEL's Qwen2-MoT transformer verbatim: the released Lance_3B checkpoint uses the identical *_moe_gen / q_norm / vae2llm / llm2vae / time_embedder / latent_pos_embed layout, so Bagel, Qwen2MoTForCausalLM, Qwen2MoTConfig, and NaiveCache are re-exported unchanged. Only the understanding ViT differs (Qwen2.5-VL vision instead of SigLIP), and the video path adds a 3-D latent position embedding.

This module also provides LanceBagel, a thin Bagel subclass that overrides the two ViT entry points to consume Qwen2.5-VL's packed pixel_values + image_grid_thw layout directly (BAGEL itself assumes SigLIP-style (C, H, W) tensors that get patchified inside the model).

NOTE ON mRoPE: Lance's backbone is Qwen2.5-VL, and its understanding / video paths use multimodal RoPE (mrope_section=[16,24,24]). BAGEL's BagelRotaryEmbedding is plain 1-D RoPE on scalar position ids. For the text2img generation path Lance assigns scalar positions (the gen latent block shares a single rope position, as in BAGEL), so the reused rotary is correct there. Full mRoPE for the x2t / video understanding path is a follow-up; see LancePositionEmbedding3D.

LANCE_SECONDS_PER_GRID module-attribute

LANCE_SECONDS_PER_GRID = 1.0

LANCE_TOKENS_PER_SECOND module-attribute

LANCE_TOKENS_PER_SECOND = 2

LANCE_VIDEO_BUCKET_STRIDE module-attribute

LANCE_VIDEO_BUCKET_STRIDE = 16

LANCE_VIDEO_MAX_DURATION module-attribute

LANCE_VIDEO_MAX_DURATION = 6.0

LANCE_VIDEO_SAMPLE_FPS module-attribute

LANCE_VIDEO_SAMPLE_FPS = 12

LANCE_VIDEO_TEMPORAL_STRIDE module-attribute

LANCE_VIDEO_TEMPORAL_STRIDE = 4

LANCE_VIDEO_VAE_DIVISIBLE_CROP module-attribute

LANCE_VIDEO_VAE_DIVISIBLE_CROP = 16

LANCE_VIDEO_VAE_RESOLUTION module-attribute

LANCE_VIDEO_VAE_RESOLUTION = 640

LANCE_VIDEO_VIT_DIVISIBLE_CROP module-attribute

LANCE_VIDEO_VIT_DIVISIBLE_CROP = 28

LANCE_VIDEO_VIT_RESOLUTION module-attribute

LANCE_VIDEO_VIT_RESOLUTION = 616

LANCE_VIT_ASPECT_RATIOS module-attribute

LANCE_VIT_ASPECT_RATIOS = (
    "21:9",
    "16:9",
    "4:3",
    "1:1",
    "3:4",
    "9:16",
)

LANCE_VIT_BUCKET_RESOLUTION module-attribute

LANCE_VIT_BUCKET_RESOLUTION = 672

LANCE_VIT_BUCKET_STRIDE module-attribute

LANCE_VIT_BUCKET_STRIDE = 16

LANCE_VIT_DIVISIBLE_CROP module-attribute

LANCE_VIT_DIVISIBLE_CROP = 28

LANCE_VIT_NORM_MEAN module-attribute

LANCE_VIT_NORM_MEAN = (0.48145466, 0.4578275, 0.40821073)

LANCE_VIT_NORM_STD module-attribute

LANCE_VIT_NORM_STD = (0.26862954, 0.26130258, 0.27577711)

LANCE_VIT_PATCH_SIZE module-attribute

LANCE_VIT_PATCH_SIZE = 14

LANCE_VIT_SPATIAL_MERGE module-attribute

LANCE_VIT_SPATIAL_MERGE = 2

LANCE_VIT_TEMPORAL_PATCH_SIZE module-attribute

LANCE_VIT_TEMPORAL_PATCH_SIZE = 2
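The ViT constants fit together as Qwen2.5-VL-style patch arithmetic: LANCE_VIT_DIVISIBLE_CROP (28) equals LANCE_VIT_PATCH_SIZE (14) times LANCE_VIT_SPATIAL_MERGE (2), so a crop divisible by 28 yields a patch grid that the 2×2 merger reduces without remainder. The sketch below works through that arithmetic. The helper vit_token_count is hypothetical, flooring H and W to the nearest multiple of the crop is an assumption about how the divisible crop is applied, and the per-patch feature size assumes the usual Qwen2.5-VL layout of channels × temporal patch × patch × patch.

```python
LANCE_VIT_PATCH_SIZE = 14
LANCE_VIT_SPATIAL_MERGE = 2
LANCE_VIT_TEMPORAL_PATCH_SIZE = 2
LANCE_VIT_DIVISIBLE_CROP = 28


def vit_token_count(height: int, width: int, num_frames: int = 1):
    """Return (grid_thw, patch_feature_dim, merged_tokens) for one image or clip.

    Hypothetical helper: floors H and W to multiples of the divisible crop and
    treats a single image as one temporal patch (grid_t = 1).
    """
    crop = LANCE_VIT_DIVISIBLE_CROP
    h = (height // crop) * crop
    w = (width // crop) * crop
    grid_t = max(1, num_frames // LANCE_VIT_TEMPORAL_PATCH_SIZE)
    grid_h = h // LANCE_VIT_PATCH_SIZE
    grid_w = w // LANCE_VIT_PATCH_SIZE
    # Each flattened patch holds 3 channels x temporal patch x patch x patch values.
    patch_dim = 3 * LANCE_VIT_TEMPORAL_PATCH_SIZE * LANCE_VIT_PATCH_SIZE ** 2
    # The merger folds each spatial_merge x spatial_merge block into one LLM token.
    merged = grid_t * grid_h * grid_w // LANCE_VIT_SPATIAL_MERGE ** 2
    return (grid_t, grid_h, grid_w), patch_dim, merged


grid, patch_dim, tokens = vit_token_count(672, 672)
print(grid, patch_dim, tokens)  # (1, 48, 48) 1176 576
```

At the 672-pixel bucket resolution a square image therefore becomes a 48 × 48 patch grid and 576 LLM tokens after merging.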

Bagel

Bases: Module

base_model_prefix class-attribute instance-attribute

base_model_prefix = 'bagel'

config instance-attribute

config = config

config_class class-attribute instance-attribute

config_class = BagelConfig

connector instance-attribute

connector = MLPconnector(
    vit_hidden_size,
    hidden_size,
    connector_act,
    quant_config=quant_config,
    prefix=f"{prefix}.connector",
)

get_flattened_position_ids instance-attribute

get_flattened_position_ids = (
    get_flattened_position_ids_extrapolate
)

hidden_size instance-attribute

hidden_size = hidden_size

language_model instance-attribute

language_model = language_model

latent_channel instance-attribute

latent_channel = z_channels

latent_downsample instance-attribute

latent_downsample = downsample * latent_patch_size

latent_patch_size instance-attribute

latent_patch_size = latent_patch_size

latent_pos_embed instance-attribute

latent_pos_embed = PositionEmbedding(
    max_latent_size, hidden_size
)

llm2vae instance-attribute

llm2vae = Linear(hidden_size, patch_latent_dim)

max_latent_size instance-attribute

max_latent_size = max_latent_size

num_heads instance-attribute

num_heads = num_attention_heads

parallel_config instance-attribute

parallel_config = parallel_config

patch_latent_dim instance-attribute

patch_latent_dim = latent_patch_size ** 2 * latent_channel

time_embedder instance-attribute

time_embedder = TimestepEmbedder(hidden_size)

timestep_shift instance-attribute

timestep_shift = timestep_shift

use_moe instance-attribute

use_moe = 'Mo' in layer_module

vae2llm instance-attribute

vae2llm = Linear(patch_latent_dim, hidden_size)

vit_hidden_size instance-attribute

vit_hidden_size = hidden_size

vit_max_num_patch_per_side instance-attribute

vit_max_num_patch_per_side = vit_max_num_patch_per_side

vit_model instance-attribute

vit_model = vit_model

vit_patch_size instance-attribute

vit_patch_size = patch_size

vit_pos_embed instance-attribute

vit_pos_embed = PositionEmbedding(
    vit_max_num_patch_per_side, hidden_size
)

forward

forward(
    x_t: Tensor,
    timestep: LongTensor,
    packed_vae_token_indexes: LongTensor,
    packed_vae_position_ids: LongTensor,
    packed_text_ids: LongTensor,
    packed_text_indexes: LongTensor,
    packed_position_ids: LongTensor,
    packed_seqlens: IntTensor,
    past_key_values: NaiveCache,
    cfg_renorm_min: float = 0.0,
    cfg_renorm_type: str = "global",
    cfg_text_scale: float = 1.0,
    cfg_img_scale: float = 1.0,
    cfg_branches: dict | None = None,
)

forward_cache_update_text

forward_cache_update_text(
    past_key_values: NaiveCache,
    packed_text_ids: IntTensor,
    packed_text_position_ids: LongTensor,
    text_token_lens: LongTensor,
)

forward_cache_update_vae

forward_cache_update_vae(
    vae_model,
    past_key_values: NaiveCache,
    padded_images: Tensor,
    patchified_vae_latent_shapes: list,
    packed_vae_position_ids: LongTensor,
    packed_timesteps: Tensor,
    packed_vae_token_indexes: LongTensor,
    packed_text_ids: LongTensor,
    packed_text_indexes: LongTensor,
    packed_position_ids: LongTensor,
    packed_seqlens: IntTensor,
)

forward_cache_update_vit

forward_cache_update_vit(
    past_key_values: NaiveCache,
    packed_text_ids: LongTensor,
    packed_text_indexes: LongTensor,
    packed_vit_tokens: Tensor,
    packed_vit_token_indexes: LongTensor,
    packed_vit_position_ids: LongTensor,
    vit_token_seqlens: IntTensor,
    packed_position_ids: LongTensor,
    packed_seqlens: IntTensor,
)

forward_single_branch

forward_single_branch(
    x_t: Tensor,
    timestep: LongTensor,
    packed_vae_token_indexes: LongTensor,
    packed_vae_position_ids: LongTensor,
    packed_text_ids: LongTensor,
    packed_text_indexes: LongTensor,
    packed_position_ids: LongTensor,
    packed_seqlens: IntTensor,
    past_key_values: NaiveCache,
) -> Tensor

Run a single-branch forward pass (no CFG batching).

Used by CFG parallel mode where each rank computes one branch. Returns the velocity v_t for the given branch. Supports Ulysses / Ring SP when parallel_config.sequence_parallel_size > 1.
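Once the conditional, text-dropped, and image-dropped branches have been computed (one per rank in CFG-parallel mode), they are merged into a single velocity. The sketch below assumes the combination rule of the original BAGEL generation loop, driven by cfg_text_scale, cfg_img_scale, cfg_renorm_type, and cfg_renorm_min. It uses NumPy instead of torch so it runs standalone, and combine_cfg is a hypothetical name, not a function in this module.

```python
import numpy as np


def combine_cfg(v_t, cfg_text_v_t, cfg_img_v_t,
                cfg_text_scale=1.0, cfg_img_scale=1.0,
                cfg_renorm_type="global", cfg_renorm_min=0.0):
    """Merge three CFG branch velocities into one (BAGEL-style sketch)."""
    if cfg_text_scale <= 1.0:
        return v_t  # guidance disabled: keep the conditional prediction
    # Text guidance: extrapolate away from the text-dropped prediction.
    v_text = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
    # Optional image guidance applied on top of the text-guided result.
    if cfg_img_scale > 1.0:
        v_guided = cfg_img_v_t + cfg_img_scale * (v_text - cfg_img_v_t)
    else:
        v_guided = v_text
    # Renormalize so guidance cannot inflate the velocity norm.
    if cfg_renorm_type == "global":
        norm_v, norm_g = np.linalg.norm(v_t), np.linalg.norm(v_guided)
    elif cfg_renorm_type == "channel":
        norm_v = np.linalg.norm(v_t, axis=-1, keepdims=True)
        norm_g = np.linalg.norm(v_guided, axis=-1, keepdims=True)
    else:
        raise ValueError(f"unsupported cfg_renorm_type: {cfg_renorm_type}")
    scale = np.clip(norm_v / (norm_g + 1e-8), cfg_renorm_min, 1.0)
    return v_guided * scale
```

Because the scale is clipped to at most 1.0, the guided velocity never has a larger norm than the conditional one (globally or per token, depending on cfg_renorm_type).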

generate_image

generate_image(
    packed_text_ids: LongTensor,
    packed_text_indexes: LongTensor,
    packed_init_noises: Tensor,
    packed_vae_position_ids: LongTensor,
    packed_vae_token_indexes: LongTensor,
    packed_seqlens: IntTensor,
    packed_position_ids: LongTensor,
    past_key_values: NaiveCache,
    num_timesteps: int = 24,
    timestep_shift: float = 1.0,
    cfg_renorm_min: float = 0.0,
    cfg_renorm_type: str = "global",
    cfg_interval: tuple[float, float] = [0, 1],
    cfg_text_scale: float = 1.0,
    cfg_text_packed_position_ids: LongTensor | None = None,
    cfg_text_past_key_values: NaiveCache | None = None,
    cfg_img_scale: float = 1.0,
    cfg_img_packed_position_ids: LongTensor | None = None,
    cfg_img_past_key_values: NaiveCache | None = None,
    return_trajectory_latents: bool = False,
    scheduler: object | None = None,
    scheduler_kwargs: dict | None = None,
    frame_condition_token_indexes: LongTensor | None = None,
)
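num_timesteps, timestep_shift, and cfg_interval together define the sampling schedule. The sketch below assumes generate_image follows the original BAGEL loop: timesteps run linearly from 1 to 0, are warped by the shift t' = s·t / (1 + (s - 1)·t), the latent takes explicit Euler steps x ← x - v·dt, and guidance is only active while t lies in (cfg_interval[0], cfg_interval[1]]. velocity_fn is a hypothetical stand-in for the model call.

```python
import numpy as np


def shifted_timesteps(num_timesteps=24, timestep_shift=1.0):
    """Linear 1 -> 0 grid warped by the timestep shift (assumed BAGEL formula)."""
    t = np.linspace(1.0, 0.0, num_timesteps)
    return timestep_shift * t / (1.0 + (timestep_shift - 1.0) * t)


def euler_sample(x, velocity_fn, num_timesteps=24, timestep_shift=1.0,
                 cfg_interval=(0.0, 1.0)):
    """Integrate from t=1 (noise) to t=0 with explicit Euler steps."""
    ts = shifted_timesteps(num_timesteps, timestep_shift)
    dts = ts[:-1] - ts[1:]
    for t, dt in zip(ts[:-1], dts):
        # Guidance is switched on only inside the configured interval.
        use_cfg = cfg_interval[0] < t <= cfg_interval[1]
        v = velocity_fn(x, t, use_cfg)
        x = x - v * dt
    return x
```

Note that num_timesteps grid points yield num_timesteps - 1 denoising steps, and a shift above 1 keeps more of the schedule near the high-noise end (for s = 3, t = 0.5 maps to 0.75).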

generate_text

generate_text(
    past_key_values: NaiveCache,
    packed_start_tokens: LongTensor,
    packed_query_position_ids: LongTensor,
    max_length: int,
    do_sample: bool = False,
    temperature: float = 1.0,
    end_token_id: int | None = None,
)

Autoregressive text generation (ported from original BAGEL).

Decodes tokens one at a time, appending to past_key_values until max_length is reached or end_token_id is generated.
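The stopping rule can be sketched as a plain decode loop over a single sequence. This is a simplified, hypothetical version: next_token_logits stands in for one incremental forward pass (which in the module also extends past_key_values), sampling draws from a temperature-scaled softmax, and greedy decoding takes the argmax.

```python
import numpy as np


def decode(next_token_logits, start_token, max_length, do_sample=False,
           temperature=1.0, end_token_id=None, seed=0):
    """Generate tokens until max_length or end_token_id (hypothetical sketch)."""
    rng = np.random.default_rng(seed)
    tokens = [start_token]
    while len(tokens) < max_length:
        logits = np.asarray(next_token_logits(tokens), dtype=np.float64)
        if do_sample:
            z = logits / temperature
            probs = np.exp(z - z.max())  # numerically stable softmax
            probs /= probs.sum()
            token = int(rng.choice(len(probs), p=probs))
        else:
            token = int(np.argmax(logits))
        tokens.append(token)
        if end_token_id is not None and token == end_token_id:
            break
    return tokens
```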

prepare_input

prepare_input(
    curr_kvlens, curr_rope, image_sizes, new_token_ids=None
)

prepare_prompts

prepare_prompts(
    curr_kvlens,
    curr_rope,
    prompts,
    tokenizer,
    new_token_ids,
)

prepare_start_tokens

prepare_start_tokens(curr_kvlens, curr_rope, new_token_ids)

Prepare start tokens for autoregressive text generation.

Ported from the original BAGEL Bagel.prepare_start_tokens.

prepare_vae_images

prepare_vae_images(
    curr_kvlens,
    curr_rope,
    images,
    transforms,
    new_token_ids,
    timestep=0,
)

prepare_vae_latent

prepare_vae_latent(
    curr_kvlens, curr_rope, image_sizes, new_token_ids
)

prepare_vae_latent_cfg

prepare_vae_latent_cfg(curr_kvlens, curr_rope, image_sizes)

prepare_vit_images

prepare_vit_images(
    curr_kvlens,
    curr_rope,
    images,
    transforms,
    new_token_ids,
)

BaseNavitOutputWithPast dataclass

Bases: ModelOutput

packed_query_sequence class-attribute instance-attribute

packed_query_sequence: FloatTensor = None

past_key_values class-attribute instance-attribute

past_key_values: NaiveCache | None = None

LanceBagel

Bases: Bagel

Bagel subclass with Lance-specific ViT handling.

The released Lance checkpoint pairs BAGEL's Qwen2-MoT trunk with the Qwen2.5-VL vision tower (whose merger already projects to LLM hidden_size and which carries its own rotary positional encoding). Two BAGEL assumptions therefore break for Lance:

  1. prepare_vit_images calls transforms(image) -> (C, H, W) and then patchify(...), whereas Lance's image processor already returns flattened (num_patches, patch_features) pixel values together with image_grid_thw.
  2. forward_cache_update_vit adds a connector projection plus a 2-D vit_pos_embed that has no checkpoint weights and would also double-count the ViT's own positional encoding.

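Those two methods are the core ViT overrides. The LLM trunk, the generation loop, and forward_cache_update_text are reused unchanged; the VAE prefill and latent-preparation entry points (forward_cache_update_vae, prepare_vae_images, prepare_vae_latent, and their video variants) are also overridden for Lance, as documented below.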

forward_cache_update_vae

forward_cache_update_vae(
    vae_model,
    past_key_values,
    padded_images=None,
    patchified_vae_latent_shapes=None,
    packed_vae_position_ids=None,
    packed_timesteps=None,
    packed_vae_token_indexes=None,
    packed_text_ids=None,
    packed_text_indexes=None,
    packed_position_ids=None,
    packed_seqlens=None,
    packed_indexes=None,
    key_values_lens=None,
    packed_key_value_indexes=None,
    precomputed_latent=None,
)

Lance-native VAE prefill that actually scatters the encoded latents into the LLM query sequence.

Bagel.forward_cache_update_vae in vllm-omni computes packed_latent = vae2llm(...) + time_embed + pos_embed but then passes only packed_text_ids to the LLM. The VAE embeddings therefore never enter the query sequence: the LLM builds it from embed_tokens(packed_text_ids), which covers only the 2 framing tokens. The mismatch between query_lens = [num_vae + 2] and the resulting 2-token sequence is what crashes the gather inside attention.

We scatter both pieces explicitly: text framing tokens at packed_text_indexes and the VAE latent embeddings at packed_vae_token_indexes, producing a full-length (sum(packed_seqlens), hidden) sequence the LLM can attend over. Empty prep data (legacy x2t / x2t_video no-op path) is short-circuited.
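A minimal NumPy sketch of that scatter follows. It assumes the framing tokens surround the latent block, as in a single request with query_lens = [num_vae + 2]; build_query_sequence is a hypothetical helper, and in the module the two inputs would come from embed_tokens(packed_text_ids) and vae2llm(...) + time_embed + pos_embed respectively.

```python
import numpy as np


def build_query_sequence(text_embeds, text_indexes, vae_embeds, vae_indexes,
                         packed_seqlens):
    """Scatter text and VAE embeddings into one (sum(seqlens), hidden) sequence."""
    total = int(np.sum(packed_seqlens))
    hidden = text_embeds.shape[-1]
    assert vae_embeds.shape[-1] == hidden
    # Every packed position must be claimed by exactly one text or VAE token.
    assert len(text_indexes) + len(vae_indexes) == total
    seq = np.zeros((total, hidden), dtype=text_embeds.dtype)
    seq[text_indexes] = text_embeds  # framing tokens
    seq[vae_indexes] = vae_embeds    # latent tokens
    return seq
```

For 4 latent tokens framed by a start and an end token, text_indexes = [0, 5], vae_indexes = [1, 2, 3, 4], and packed_seqlens = [6], so the sequence length now matches query_lens.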

forward_cache_update_vit

forward_cache_update_vit(*args, **kwargs)

prepare_prompts

prepare_prompts(*args, **kwargs)

prepare_start_tokens

prepare_start_tokens(*args, **kwargs)

prepare_vae_images

prepare_vae_images(
    curr_kvlens,
    curr_rope,
    images,
    transforms,
    new_token_ids,
    timestep=0,
)

VAE prefill router.

  • When images is non-empty (image_edit / video_edit path on the image side): delegate to _lance_native_prepare_vae_images, which emits Lance's 3-D mRoPE positions and a real VAE prefill, letting BAGEL's parent image_edit flow handle the rest.
  • When images is empty (t2i / x2t paths): short-circuit with the no-op output, mirroring BAGEL's "no image to prefill" sentinel.

prepare_vae_latent

prepare_vae_latent(
    curr_kvlens, curr_rope, image_sizes, new_token_ids
)

prepare_vae_latent_cfg

prepare_vae_latent_cfg(curr_kvlens, curr_rope, image_sizes)

prepare_video_latent

prepare_video_latent(
    curr_kvlens, curr_rope, video_shapes, new_token_ids
)

3-D analogue of prepare_vae_latent.

video_shapes is a list of (T, H, W) per request (RGB pixel space). We package one packed-init-noise tensor over T_lat × H_lat × W_lat latent tokens per video, plus 1-D indices into the 3-D position embedding table maintained by LancePositionEmbedding3D (bagel.latent_pos_embed). Latent geometry:

  • spatial: H_lat = H // latent_downsample and W_lat = W // latent_downsample (latent_downsample = 16 for Lance)
  • temporal: T_lat = (T - 1) // downsample_temporal + 1 (downsample_temporal = 4 for the Wan2.2 VAE)
  • channels: latent_channel = 48

Position ids are flattened t * max_per_side² + h * max_per_side + w so they index directly into the (max_num_frames * max_per_side², hidden_size) table.
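The geometry and flattening rules above can be checked with a short sketch. max_per_side is a hypothetical value here (the real one comes from the model configuration), video_latent_positions is a hypothetical helper, and the defaults of 16 and 4 follow the Lance and Wan2.2 figures quoted above.

```python
import numpy as np


def video_latent_positions(T, H, W, max_per_side, latent_downsample=16,
                           downsample_temporal=4):
    """Return (T_lat, H_lat, W_lat) and the flattened 3-D position ids."""
    t_lat = (T - 1) // downsample_temporal + 1
    h_lat = H // latent_downsample
    w_lat = W // latent_downsample
    assert h_lat <= max_per_side and w_lat <= max_per_side
    t, h, w = np.meshgrid(np.arange(t_lat), np.arange(h_lat), np.arange(w_lat),
                          indexing="ij")
    # Row-major (t, h, w) index into a (max_num_frames * max_per_side**2, hidden) table.
    ids = t * max_per_side ** 2 + h * max_per_side + w
    return (t_lat, h_lat, w_lat), ids.reshape(-1)


shape, ids = video_latent_positions(17, 640, 640, max_per_side=64)
```

A 17-frame 640 × 640 clip becomes 5 × 40 × 40 = 8000 latent tokens. With T = 1 the formula gives T_lat = 1, which is the image case.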

prepare_video_latent_cfg

prepare_video_latent_cfg(
    curr_kvlens, curr_rope, video_shapes
)

3-D analogue of prepare_vae_latent_cfg (CFG side).

Mirrors prepare_video_latent's 3-D mRoPE position layout exactly, including the LANCE_TOKENS_PER_SECOND * LANCE_SECONDS_PER_GRID temporal scaling. Without that scaling, the cfg_text branch attends with different rope coordinates than the cond branch (and than upstream's get_rope_index), so cfg_text_v_t diverges from upstream and the CFG combination amplifies the error at every denoising step.
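The temporal scaling amounts to spacing consecutive latent frames LANCE_TOKENS_PER_SECOND * LANCE_SECONDS_PER_GRID rope positions apart. The sketch assumes the same rule as Qwen2.5-VL's get_rope_index (frame index × seconds per grid × tokens per second, truncated to an integer) plus an arbitrary start offset; temporal_rope_positions is a hypothetical helper. The point is that the cond and cfg_text branches must derive their coordinates from the same function.

```python
import numpy as np

LANCE_TOKENS_PER_SECOND = 2
LANCE_SECONDS_PER_GRID = 1.0


def temporal_rope_positions(t_lat, start=0):
    """Temporal mRoPE coordinate of each latent frame (assumed Qwen2.5-VL rule)."""
    step = LANCE_TOKENS_PER_SECOND * LANCE_SECONDS_PER_GRID
    return start + (np.arange(t_lat) * step).astype(np.int64)
```

With the constants above, four latent frames land at temporal positions 0, 2, 4, 6; an unscaled branch would use 0, 1, 2, 3 and disagree from the second frame on.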

prepare_vit_images

prepare_vit_images(
    curr_kvlens,
    curr_rope,
    images,
    transforms,
    new_token_ids,
)

prepare_vit_videos

prepare_vit_videos(
    curr_kvlens,
    curr_rope,
    videos,
    new_token_ids,
    precomputed_vit=None,
)

Multi-frame ViT prefill for the x2t_video / video_edit paths.

videos is a list of per-request video tensors / numpy arrays of shape (T, H, W, 3). By default the Qwen2-VL video processor is used to convert each video to (pixel_values_videos, video_grid_thw). For video_edit precision matching, the pipeline may pre-compute the upstream-style BucketResize output and pass it via precomputed_vit — a list of (pixel_values, grid_thw) per video, in which case the processor call is skipped.

LanceIdentityConnector

Bases: Module

No-op connector for Lance.

BAGEL's connector projects the ViT hidden size to the LLM hidden size. Qwen2.5-VL's vision tower (which Lance uses) already projects to the LLM hidden size internally via merger (out_hidden_size = hidden_size), and the released Lance safetensors carry no connector.* weights. We therefore plug in an Identity connector so forward_cache_update_vit keeps its existing call site without a separate code path.

forward

forward(x: Tensor) -> Tensor

LancePositionEmbedding3D

Bases: Module

Frozen 3-D sin-cos latent position embedding for the video path.

BAGEL only ships a 2-D PositionEmbedding (image latents). Lance's Lance_3B_Video checkpoint adds a temporal axis; this mirrors upstream modeling/lance/modeling_utils.py::PositionEmbedding3D. The image path uses t=1 and is numerically equivalent to the 2-D embedding.

hidden_size instance-attribute

hidden_size = hidden_size

max_num_frames instance-attribute

max_num_frames = max_num_frames

max_num_patch_per_side instance-attribute

max_num_patch_per_side = max_num_patch_per_side

pos_embed instance-attribute

pos_embed = Parameter(
    zeros(n, hidden_size), requires_grad=False
)

forward

forward(position_ids: Tensor) -> Tensor

LanceQwen2_5_VLNaViTWrapper

Bases: Module

Packed (NaViT-style) wrapper around the Qwen2.5-VL vision tower.

Bridges BAGEL's vit(packed_pixel_values, ...) -> [num_tokens, vit_hidden] interface to the HF Qwen2_5_VisionTransformerPretrainedModel, which consumes (hidden_states, grid_thw). The packed call additionally needs a per-image image_grid_thw so that non-square images (and the spatial-merge token count) line up; LanceBagel stashes the grid on the wrapper before invoking the ViT.
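The per-image grid matters because token counts and attention boundaries are both derived from it. The sketch below, using the hypothetical name packed_vit_layout, computes how many patches each image contributes, how many tokens remain after the 2×2 merger, and segment boundaries in the style of the HF Qwen2.5-VL vision tower's cu_seqlens (one segment per temporal slice of h × w patches). It does not reproduce the tower's windowed-attention indexing.

```python
import numpy as np


def packed_vit_layout(grid_thw, spatial_merge_size=2):
    """Per-image token counts and cu_seqlens for a packed ViT batch."""
    grid_thw = np.asarray(grid_thw, dtype=np.int64)
    t, h, w = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
    patches = t * h * w                          # tokens entering the ViT
    merged = patches // spatial_merge_size ** 2  # tokens leaving the merger
    # One attention segment per temporal slice, as in cu_seqlens of the HF tower.
    frame_sizes = np.repeat(h * w, t)
    cu_seqlens = np.concatenate([[0], np.cumsum(frame_sizes)])
    return patches, merged, cu_seqlens
```

For a square 48 × 48 grid next to a non-square 24 × 34 grid, the packed sequence splits at 2304 patches, and the merger leaves 576 and 204 tokens respectively.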

config property

config

spatial_merge_size instance-attribute

spatial_merge_size = spatial_merge_size

vision_model instance-attribute

vision_model = getattr(vision_model, "visual", vision_model)

forward

forward(
    packed_pixel_values: Tensor,
    packed_flattened_position_ids: Tensor,
    cu_seqlens: Tensor,
    max_seqlen: int,
) -> Tensor

set_pending_grid_thw

set_pending_grid_thw(grid_thw: Tensor) -> None

LanceZeroVitPosEmbed

Bases: Module

No-op positional embedding for Lance's ViT tokens.

BAGEL adds an extra 2-D sin-cos vit_pos_embed on top of the ViT output. Qwen2.5-VL's vision tower already carries its own (rotary) positional encoding, and the released Lance safetensors carry no vit_pos_embed.* weights. This module returns a broadcast-friendly zero so the addition in forward_cache_update_vit is a no-op without requiring a code-path branch.
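Together with LanceIdentityConnector, this keeps BAGEL's call site connector(vit_out) + vit_pos_embed(position_ids) intact while making it a pure pass-through. The NumPy stand-ins below are hypothetical and only illustrate the design choice; returning a 0-d zero is one assumed way to be "broadcast-friendly".

```python
import numpy as np


class IdentityConnector:
    """Stand-in for LanceIdentityConnector: returns its input unchanged."""

    def __call__(self, x):
        return x


class ZeroPosEmbed:
    """Stand-in for LanceZeroVitPosEmbed: a 0-d zero broadcasts over any shape."""

    def __call__(self, position_ids):
        return np.zeros((), dtype=np.float32)


def vit_call_site(vit_out, position_ids, connector, pos_embed):
    # Same expression shape as BAGEL's call site; no Lance-specific branch needed.
    return connector(vit_out) + pos_embed(position_ids)
```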

forward

forward(position_ids: Tensor) -> Tensor

MLPconnector

Bases: Module

act instance-attribute

act = GELU()

fc1 instance-attribute

fc1 = ColumnParallelLinear(
    input_dim,
    output_dim,
    bias=True,
    gather_output=False,
    quant_config=quant_config,
    prefix=f"{prefix}.fc1",
)

fc2 instance-attribute

fc2 = RowParallelLinear(
    output_dim,
    output_dim,
    bias=True,
    input_is_parallel=True,
    quant_config=quant_config,
    prefix=f"{prefix}.fc2",
)

forward

forward(x)
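MLPconnector pairs a ColumnParallelLinear (gather_output=False) with a RowParallelLinear (input_is_parallel=True), the standard Megatron-style split: each rank keeps a slice of fc1's output features, applies GELU locally, and multiplies by the matching slice of fc2's input features; summing the partial results (the all-reduce) reproduces the unsharded MLP. The NumPy sketch below simulates tp_size shards in one process. It uses the tanh approximation of GELU for brevity and is not the vLLM implementation.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU; elementwise, so it commutes with column sharding
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def mlp_full(x, w1, b1, w2, b2):
    return gelu(x @ w1 + b1) @ w2 + b2


def mlp_tensor_parallel(x, w1, b1, w2, b2, tp_size):
    shards = np.array_split(np.arange(w1.shape[1]), tp_size)
    partial = 0.0
    for idx in shards:
        # Column-parallel fc1: this "rank" owns output features idx.
        h = gelu(x @ w1[:, idx] + b1[idx])
        # Row-parallel fc2: matching input rows; partial sums model the all-reduce.
        partial = partial + h @ w2[idx, :]
    # The fc2 bias is added once, after the reduction.
    return partial + b2
```

The design avoids any communication between fc1 and fc2: the intermediate activation stays sharded, and only one reduction is needed at the end.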

NaiveCache

key_cache instance-attribute

key_cache = {k: None for k in (range(num_layers))}

num_layers property

num_layers

seq_lens property

seq_lens

value_cache instance-attribute

value_cache = {k: None for k in (range(num_layers))}

PackedAttentionMoT

Bases: Module

Packed attention with Mixture-of-Transformers (MoT) routing for understanding/generation.

Uses vLLM's QKVParallelLinear and RowParallelLinear for tensor parallelism support, following the same pattern as vLLM's Qwen2Attention.

The q/k/v projections of each branch are stacked into one QKVParallelLinear:
  • qkv_proj: stacks q_proj + k_proj + v_proj (understanding + gen text)
  • qkv_proj_moe_gen: stacks q_proj_moe_gen + k_proj_moe_gen + v_proj_moe_gen (gen VAE)
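Token routing between the two projection stacks can be sketched as follows: text positions (packed_text_indexes) go through the understanding weights, VAE positions (packed_vae_token_indexes) go through the moe_gen weights, and the results are written back in packed order before the fused output is split into q, k, and v. This NumPy illustration uses plain weight matrices and a hypothetical route_qkv helper; the real module uses QKVParallelLinear and applies the per-branch q_norm / k_norm afterwards.

```python
import numpy as np


def route_qkv(hidden, text_idx, vae_idx, w_und, w_gen, q_size, kv_size):
    """Project each token with its branch's fused QKV weight, then split q/k/v."""
    qkv = np.empty((hidden.shape[0], w_und.shape[1]), dtype=hidden.dtype)
    qkv[text_idx] = hidden[text_idx] @ w_und  # qkv_proj (understanding + gen text)
    qkv[vae_idx] = hidden[vae_idx] @ w_gen    # qkv_proj_moe_gen (gen VAE tokens)
    q, k, v = np.split(qkv, [q_size, q_size + kv_size], axis=-1)
    return q, k, v
```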

attn_causal instance-attribute

attn_causal = Attention(
    num_heads=total_num_heads,
    head_size=head_dim,
    softmax_scale=1.0 / head_dim**0.5,
    causal=True,
    num_kv_heads=total_num_kv_heads,
)

attn_noncausal instance-attribute

attn_noncausal = Attention(
    num_heads=total_num_heads,
    head_size=head_dim,
    softmax_scale=1.0 / head_dim**0.5,
    causal=False,
    num_kv_heads=total_num_kv_heads,
)

head_dim instance-attribute

head_dim = hidden_size // total_num_heads

hidden_size instance-attribute

hidden_size = hidden_size

k_norm instance-attribute

k_norm = RMSNorm(head_dim, eps=rms_norm_eps)

k_norm_moe_gen instance-attribute

k_norm_moe_gen = RMSNorm(head_dim, eps=rms_norm_eps)

kv_size instance-attribute

kv_size = num_kv_heads * head_dim

layer_idx instance-attribute

layer_idx = layer_idx

num_heads instance-attribute

num_heads = total_num_heads // tp_size

num_kv_heads instance-attribute

num_kv_heads = max(1, total_num_kv_heads // tp_size)

o_proj instance-attribute

o_proj = RowParallelLinear(
    total_num_heads * head_dim,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj",
)

o_proj_moe_gen instance-attribute

o_proj_moe_gen = RowParallelLinear(
    total_num_heads * head_dim,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj_moe_gen",
)

parallel_config instance-attribute

parallel_config = parallel_config

q_norm instance-attribute

q_norm = RMSNorm(head_dim, eps=rms_norm_eps)

q_norm_moe_gen instance-attribute

q_norm_moe_gen = RMSNorm(head_dim, eps=rms_norm_eps)

q_size instance-attribute

q_size = num_heads * head_dim

qkv_proj instance-attribute

qkv_proj = QKVParallelLinear(
    hidden_size,
    head_dim,
    total_num_heads,
    total_num_kv_heads,
    bias=True,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

qkv_proj_moe_gen instance-attribute

qkv_proj_moe_gen = QKVParallelLinear(
    hidden_size,
    head_dim,
    total_num_heads,
    total_num_kv_heads,
    bias=True,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj_moe_gen",
)

rotary_op instance-attribute

rotary_op = RotaryEmbedding(is_neox_style=True)

total_num_heads instance-attribute

total_num_heads = num_attention_heads

total_num_kv_heads instance-attribute

total_num_kv_heads = num_key_value_heads
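The per-rank sizes follow directly from the attribute definitions above. The configuration numbers in the usage line are illustrative, not read from a Lance checkpoint, and per_rank_attention_sizes is a hypothetical helper.

```python
def per_rank_attention_sizes(hidden_size, total_num_heads, total_num_kv_heads,
                             tp_size):
    """Return (head_dim, num_heads, num_kv_heads, q_size, kv_size) per TP rank."""
    head_dim = hidden_size // total_num_heads
    num_heads = total_num_heads // tp_size
    # max(1, ...) keeps at least one KV head per rank; when tp_size exceeds the
    # KV head count, KV heads are replicated across ranks.
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    q_size = num_heads * head_dim
    kv_size = num_kv_heads * head_dim
    return head_dim, num_heads, num_kv_heads, q_size, kv_size


sizes_tp4 = per_rank_attention_sizes(2048, 16, 2, tp_size=4)  # (128, 4, 1, 512, 128)
```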

forward

forward(
    packed_query_sequence: Tensor,
    query_lens: Tensor,
    packed_query_position_embeddings: Tensor,
    past_key_values: NaiveCache | None = None,
    update_past_key_values=True,
    is_causal=True,
    mode="und",
    packed_vae_token_indexes=None,
    packed_text_indexes=None,
)

PositionEmbedding

Bases: Module

hidden_size instance-attribute

hidden_size = hidden_size

max_num_patch_per_side instance-attribute

max_num_patch_per_side = max_num_patch_per_side

pos_embed instance-attribute

pos_embed = Parameter(
    zeros(max_num_patch_per_side**2, hidden_size),
    requires_grad=False,
)

forward

forward(position_ids)
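PositionEmbedding stores a frozen (max_num_patch_per_side**2, hidden_size) table. A common way to fill such a table is the MAE-style 2-D sin-cos embedding, where half of the channels encode the row and half the column. The sketch below assumes that construction and a row-major h * side + w layout, under which forward(position_ids) reduces to a row lookup; sincos_1d and sincos_2d_table are hypothetical names.

```python
import numpy as np


def sincos_1d(dim, pos):
    """(len(pos), dim) table: first half sin, second half cos."""
    omega = 1.0 / 10000 ** (np.arange(dim // 2, dtype=np.float64) / (dim / 2.0))
    out = np.outer(pos, omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)


def sincos_2d_table(hidden_size, side):
    """Row-major (side**2, hidden_size) table; row id = h * side + w."""
    h, w = np.meshgrid(np.arange(side), np.arange(side), indexing="ij")
    emb_h = sincos_1d(hidden_size // 2, h.reshape(-1))
    emb_w = sincos_1d(hidden_size // 2, w.reshape(-1))
    return np.concatenate([emb_h, emb_w], axis=1)


table = sincos_2d_table(hidden_size=16, side=4)
position_ids = np.array([0, 5, 15])
embeds = table[position_ids]  # what a lookup-style forward(position_ids) returns
```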

Qwen2MoTConfig

Bases: Qwen2Config

Configuration for Qwen2MoT (Mixture-of-Transformers) model.

This is fundamentally different from Qwen2, hence the distinct name.

keys_to_ignore_at_inference class-attribute instance-attribute

keys_to_ignore_at_inference = ['past_key_values']

layer_module instance-attribute

layer_module = layer_module

model_type class-attribute instance-attribute

model_type = 'qwen2_mot'

qk_norm instance-attribute

qk_norm = qk_norm

Qwen2MoTDecoderLayer

Bases: Module

hidden_size instance-attribute

hidden_size = hidden_size

input_layernorm instance-attribute

input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

input_layernorm_moe_gen instance-attribute

input_layernorm_moe_gen = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

mlp instance-attribute

mlp = BagelMLP(
    hidden_size,
    intermediate_size,
    hidden_act,
    quant_config=quant_config,
    prefix=f"{prefix}.mlp",
)

mlp_moe_gen instance-attribute

mlp_moe_gen = BagelMLP(
    hidden_size,
    intermediate_size,
    hidden_act,
    quant_config=quant_config,
    prefix=f"{prefix}.mlp_moe_gen",
)

post_attention_layernorm instance-attribute

post_attention_layernorm = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

post_attention_layernorm_moe_gen instance-attribute

post_attention_layernorm_moe_gen = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

self_attn instance-attribute

self_attn = attn_module(
    config,
    layer_idx,
    parallel_config=parallel_config,
    quant_config=quant_config,
    prefix=f"{prefix}.self_attn",
)

forward

forward(
    hidden_states: Tensor,
    encoder_hidden_states: Tensor | None = None,
    packed_query_sequence: Tensor | None = None,
    query_lens: Tensor = None,
    packed_query_position_embeddings: Tensor = None,
    past_key_values: NaiveCache | None = None,
    update_past_key_values=True,
    is_causal=True,
    mode="und",
    packed_vae_token_indexes=None,
    packed_text_indexes=None,
) -> BaseNavitOutputWithPast

Qwen2MoTForCausalLM

Bases: Qwen2PreTrainedModel

lm_head instance-attribute

lm_head = Linear(hidden_size, vocab_size, bias=False)

model instance-attribute

model = Qwen2MoTModel(
    config,
    parallel_config=parallel_config,
    quant_config=quant_config,
    prefix=f"{prefix}.model",
)

vocab_size instance-attribute

vocab_size = vocab_size

forward

forward(
    packed_query_sequence: Tensor | None = None,
    query_lens: Tensor | None = None,
    packed_query_position_ids: Tensor | None = None,
    past_key_values: NaiveCache | None = None,
    update_past_key_values=True,
    is_causal=True,
    mode="und",
    packed_vae_token_indexes=None,
    packed_text_indexes=None,
    packed_text_ids: Tensor | None = None,
    return_embeddings_only: bool = False,
) -> BaseNavitOutputWithPast

get_decoder

get_decoder()

get_input_embeddings

get_input_embeddings()

get_output_embeddings

get_output_embeddings()

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights for vLLM parallel layers.

Handles stacked-parameter remapping for QKVParallelLinear:
  • q_proj, k_proj, v_proj -> qkv_proj (shard ids: q, k, v)
  • q_proj_moe_gen, k_proj_moe_gen, v_proj_moe_gen -> qkv_proj_moe_gen

Other parallel layers (gate_proj, up_proj, down_proj, embed_tokens, etc.) keep HF checkpoint names and use weight_loader for TP sharding.
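The remapping rule can be sketched in plain Python. The table layout follows the (param_name, checkpoint_name, shard_id) convention common in vLLM model files, but remap is a hypothetical helper; the real loader additionally calls each parameter's weight_loader with the shard id.

```python
# (vLLM param name, HF checkpoint name, shard id)
STACKED_PARAMS = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("qkv_proj_moe_gen", "q_proj_moe_gen", "q"),
    ("qkv_proj_moe_gen", "k_proj_moe_gen", "k"),
    ("qkv_proj_moe_gen", "v_proj_moe_gen", "v"),
]


def remap(name):
    """Map an HF checkpoint name to (vLLM param name, shard id or None)."""
    parts = name.split(".")
    for param, ckpt, shard in STACKED_PARAMS:
        # Exact dotted-segment matching keeps q_proj from matching q_proj_moe_gen.
        if ckpt in parts:
            parts[parts.index(ckpt)] = param
            return ".".join(parts), shard
    return name, None  # e.g. gate_proj / up_proj / down_proj keep their names
```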

set_decoder

set_decoder(decoder)

set_input_embeddings

set_input_embeddings(value)

set_output_embeddings

set_output_embeddings(new_embeddings)

Qwen2MoTModel

Bases: Qwen2PreTrainedModel

embed_tokens instance-attribute

embed_tokens = VocabParallelEmbedding(
    vocab_size, hidden_size
)

layers instance-attribute

layers = ModuleList(
    [
        (
            Qwen2MoTDecoderLayer(
                config,
                layer_idx,
                attn_module=PackedAttentionMoT,
                parallel_config=parallel_config,
                quant_config=quant_config,
                prefix=f"{prefix}.layers.{layer_idx}",
            )
        )
        for layer_idx in (range(num_hidden_layers))
    ]
)

norm instance-attribute

norm = RMSNorm(hidden_size, eps=rms_norm_eps)

norm_moe_gen instance-attribute

norm_moe_gen = RMSNorm(hidden_size, eps=rms_norm_eps)

padding_idx instance-attribute

padding_idx = pad_token_id

rotary_emb instance-attribute

rotary_emb = BagelRotaryEmbedding(config=config)

use_moe instance-attribute

use_moe = 'Mo' in layer_module

vocab_size instance-attribute

vocab_size = vocab_size

forward

forward(
    packed_query_sequence: Tensor | None = None,
    query_lens: Tensor | None = None,
    packed_query_position_ids: Tensor | None = None,
    past_key_values: NaiveCache | None = None,
    update_past_key_values=True,
    is_causal=True,
    mode="und",
    packed_vae_token_indexes=None,
    packed_text_indexes=None,
    packed_text_ids: Tensor | None = None,
    return_embeddings_only: bool = False,
) -> BaseNavitOutputWithPast

TimestepEmbedder

Bases: Module

Embeds scalar timesteps into vector representations.

frequency_embedding_size instance-attribute

frequency_embedding_size = frequency_embedding_size

mlp instance-attribute

mlp = Sequential(
    Linear(
        frequency_embedding_size, hidden_size, bias=True
    ),
    SiLU(),
    Linear(hidden_size, hidden_size, bias=True),
)

forward

forward(t)
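The mlp above consumes a frequency_embedding_size-dimensional sinusoidal encoding of t. The sketch below assumes the DiT-style encoding (cosines followed by sines, max_period = 10000); the default of 256 frequencies and all function names are assumptions, and plain NumPy matrices stand in for the Linear layers.

```python
import numpy as np


def timestep_frequency_embedding(t, dim=256, max_period=10000.0):
    """(len(t), dim) sinusoidal encoding of scalar timesteps (DiT-style sketch)."""
    t = np.asarray(t, dtype=np.float64).reshape(-1)
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = t[:, None] * freqs[None, :]
    emb = np.concatenate([np.cos(args), np.sin(args)], axis=-1)
    if dim % 2:  # pad one zero column for odd dims
        emb = np.concatenate([emb, np.zeros_like(emb[:, :1])], axis=-1)
    return emb


def silu(x):
    return x / (1.0 + np.exp(-x))


def timestep_embedder(t, w1, b1, w2, b2, frequency_embedding_size=256):
    """Linear -> SiLU -> Linear on top of the frequency encoding."""
    freq = timestep_frequency_embedding(t, frequency_embedding_size)
    return silu(freq @ w1 + b1) @ w2 + b2
```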

get_3d_sincos_pos_embed

get_3d_sincos_pos_embed(
    embed_dim: int, t: int, h: int, w: int
) -> ndarray

get_3d_sincos_pos_embed_from_grid

get_3d_sincos_pos_embed_from_grid(
    embed_dim: int, grid: ndarray
) -> ndarray

3-D sin-cos positional embedding (t, h, w), matching the upstream Lance modeling/lance/modeling_utils.py dimension split exactly.

patchify

patchify(imgs, p)

imgs: (N, 3, H, W) or (3, H, W)
x: (N, L, patch_size**2 * 3) or (L, patch_size**2 * 3)
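A NumPy sketch of that shape contract follows. Only the input and output shapes come from the docstring; the channels-last (p, p, C) ordering inside each patch and the row-major patch order are assumptions.

```python
import numpy as np


def patchify(imgs, p):
    """(N, C, H, W) -> (N, L, p*p*C), or (C, H, W) -> (L, p*p*C)."""
    squeeze = imgs.ndim == 3
    if squeeze:
        imgs = imgs[None]
    n, c, h, w = imgs.shape
    assert h % p == 0 and w % p == 0, "H and W must be multiples of p"
    x = imgs.reshape(n, c, h // p, p, w // p, p)
    # Reorder to (N, h, w, p, p, C): each patch becomes contiguous, channels last.
    x = np.einsum("nchpwq->nhwpqc", x)
    x = x.reshape(n, (h // p) * (w // p), p * p * c)
    return x[0] if squeeze else x
```

With p = 14, a 3 × 28 × 28 image yields L = 4 patches of 14 · 14 · 3 = 588 values each.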