vllm_omni.diffusion.models.bagel.bagel_transformer ¶
Bagel ¶
Bases: Module
connector instance-attribute ¶
connector = MLPconnector(
self.vit_hidden_size,
self.hidden_size,
config.connector_act,
quant_config=quant_config,
prefix=f"{prefix}.connector",
)
get_flattened_position_ids instance-attribute ¶
get_flattened_position_ids = (
get_flattened_position_ids_extrapolate
)
latent_downsample instance-attribute ¶
latent_pos_embed instance-attribute ¶
latent_pos_embed = PositionEmbedding(
self.max_latent_size, self.hidden_size
)
patch_latent_dim instance-attribute ¶
vit_max_num_patch_per_side instance-attribute ¶
vit_pos_embed instance-attribute ¶
vit_pos_embed = PositionEmbedding(
self.vit_max_num_patch_per_side, self.hidden_size
)
forward ¶
forward(
x_t: Tensor,
timestep: LongTensor,
packed_vae_token_indexes: LongTensor,
packed_vae_position_ids: LongTensor,
packed_text_ids: LongTensor,
packed_text_indexes: LongTensor,
packed_position_ids: LongTensor,
packed_seqlens: IntTensor,
past_key_values: NaiveCache,
cfg_renorm_min: float = 0.0,
cfg_renorm_type: str = "global",
cfg_text_scale: float = 1.0,
cfg_img_scale: float = 1.0,
cfg_branches: dict | None = None,
)
forward_cache_update_text ¶
forward_cache_update_text(
past_key_values: NaiveCache,
packed_text_ids: IntTensor,
packed_text_position_ids: LongTensor,
text_token_lens: LongTensor,
)
forward_cache_update_vae ¶
forward_cache_update_vae(
vae_model,
past_key_values: NaiveCache,
padded_images: Tensor,
patchified_vae_latent_shapes: list,
packed_vae_position_ids: LongTensor,
packed_timesteps: Tensor,
packed_vae_token_indexes: LongTensor,
packed_text_ids: LongTensor,
packed_text_indexes: LongTensor,
packed_position_ids: LongTensor,
packed_seqlens: IntTensor,
)
forward_cache_update_vit ¶
forward_cache_update_vit(
past_key_values: NaiveCache,
packed_text_ids: LongTensor,
packed_text_indexes: LongTensor,
packed_vit_tokens: Tensor,
packed_vit_token_indexes: LongTensor,
packed_vit_position_ids: LongTensor,
vit_token_seqlens: IntTensor,
packed_position_ids: LongTensor,
packed_seqlens: IntTensor,
)
forward_single_branch ¶
forward_single_branch(
x_t: Tensor,
timestep: LongTensor,
packed_vae_token_indexes: LongTensor,
packed_vae_position_ids: LongTensor,
packed_text_ids: LongTensor,
packed_text_indexes: LongTensor,
packed_position_ids: LongTensor,
packed_seqlens: IntTensor,
past_key_values: NaiveCache,
) -> Tensor
Run a single-branch forward pass (no CFG batching).
Used by CFG parallel mode where each rank computes one branch. Returns the velocity v_t for the given branch. Supports Ulysses / Ring SP when parallel_config.sequence_parallel_size > 1.
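To illustrate how per-branch velocities could be recombined once each rank has returned its v_t, here is a minimal sketch of the generic classifier-free guidance formula on plain Python lists. The helper names `cfg_mix` and `combine_cfg`, the three-branch layout, and the order of nesting (text guidance first, then image guidance) are assumptions made for illustration; renormalization (`cfg_renorm_min`, `cfg_renorm_type`) is left out, and the combination actually performed by this module may differ.

```python
def cfg_mix(v_uncond, v_cond, scale):
    """Generic CFG step: move from the unconditional velocity toward the conditional one."""
    return [u + scale * (c - u) for u, c in zip(v_uncond, v_cond)]


def combine_cfg(v_full, v_no_text, v_no_img, cfg_text_scale=1.0, cfg_img_scale=1.0):
    """Hypothetical recombination of three single-branch velocities.

    v_full:    velocity from the fully conditioned branch
    v_no_text: velocity from the branch without text conditioning (assumed)
    v_no_img:  velocity from the branch without image conditioning (assumed)
    """
    v = cfg_mix(v_no_text, v_full, cfg_text_scale)  # text guidance
    v = cfg_mix(v_no_img, v, cfg_img_scale)         # image guidance
    return v


# With both scales at 1.0 the guided velocity equals the fully conditioned one.
unguided = combine_cfg([1.0, 2.0], [0.0, 0.0], [0.5, 0.5])
guided = combine_cfg([1.0, 2.0], [0.0, 0.0], [0.5, 0.5], cfg_text_scale=2.0)
```

With the default scales of 1.0 the guidance terms cancel, which is consistent with the defaults shown in the signatures on this page.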
generate_image ¶
generate_image(
packed_text_ids: LongTensor,
packed_text_indexes: LongTensor,
packed_init_noises: Tensor,
packed_vae_position_ids: LongTensor,
packed_vae_token_indexes: LongTensor,
packed_seqlens: IntTensor,
packed_position_ids: LongTensor,
past_key_values: NaiveCache,
num_timesteps: int = 24,
timestep_shift: float = 1.0,
cfg_renorm_min: float = 0.0,
cfg_renorm_type: str = "global",
cfg_interval: tuple[float, float] = [0, 1],
cfg_text_scale: float = 1.0,
cfg_text_packed_position_ids: LongTensor | None = None,
cfg_text_past_key_values: NaiveCache | None = None,
cfg_img_scale: float = 1.0,
cfg_img_packed_position_ids: LongTensor | None = None,
cfg_img_past_key_values: NaiveCache | None = None,
return_trajectory_latents: bool = False,
scheduler: object | None = None,
scheduler_kwargs: dict | None = None,
frame_condition_token_indexes: LongTensor | None = None,
)
generate_text ¶
generate_text(
past_key_values: NaiveCache,
packed_start_tokens: LongTensor,
packed_query_position_ids: LongTensor,
max_length: int,
do_sample: bool = False,
temperature: float = 1.0,
end_token_id: int | None = None,
)
Autoregressive text generation (ported from original BAGEL).
Decodes tokens one at a time, appending to past_key_values until max_length is reached or end_token_id is generated.
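The decoding loop described above can be sketched in plain Python as follows. This is not the module's implementation: `decode`, `step_fn`, and `toy_step` are hypothetical names, the cache is a simple list standing in for `NaiveCache`, and whether the end token itself is included in the returned sequence is an assumption.

```python
import math
import random


def decode(step_fn, start_token, max_length, do_sample=False,
           temperature=1.0, end_token_id=None, rng=None):
    """Generate tokens one at a time until max_length or end_token_id.

    step_fn(token, cache) -> list of logits; it is expected to append to cache.
    """
    rng = rng or random.Random(0)
    cache = []  # stand-in for past_key_values
    token = start_token
    generated = []
    for _ in range(max_length):
        logits = step_fn(token, cache)
        if do_sample:
            # Temperature-scaled softmax sampling (numerically stabilized).
            scaled = [value / temperature for value in logits]
            peak = max(scaled)
            weights = [math.exp(value - peak) for value in scaled]
            token = rng.choices(range(len(logits)), weights=weights, k=1)[0]
        else:
            # Greedy decoding: pick the highest-scoring token.
            token = max(range(len(logits)), key=logits.__getitem__)
        generated.append(token)
        if end_token_id is not None and token == end_token_id:
            break
    return generated


def toy_step(token, cache):
    """Toy model over a 5-token vocabulary that always prefers token + 1."""
    cache.append(token)
    vocab = 5
    return [1.0 if i == (token + 1) % vocab else 0.0 for i in range(vocab)]


stopped_early = decode(toy_step, 0, max_length=10, end_token_id=3)
hit_max_length = decode(toy_step, 0, max_length=2)
```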
prepare_start_tokens ¶
Prepare start tokens for autoregressive text generation.
Ported from the original BAGEL Bagel.prepare_start_tokens.
prepare_vae_images ¶
prepare_vit_images ¶
BagelMLP ¶
Bases: Module
FFN with Mixture-of-Tokens routing via MoT parallel linear layers.
gate_proj + up_proj are fused into a single MoTMergedColumnParallelLinear. down_proj uses MoTRowParallelLinear. Both layers hold text weights on self and vae weights on self.gen_exp, routing by text_indices / vae_indices.
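The index-based routing described above can be illustrated with a small sketch that uses nested lists instead of tensors. The function names and the two-weight layout are hypothetical; the real layers are tensor-parallel and may also be quantized, neither of which is modeled here.

```python
def linear(weight, x):
    """Apply a dense weight matrix (list of rows) to a single token vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def mot_linear(tokens, text_weight, gen_weight, text_indices, vae_indices):
    """Route each packed token to the text or the generation (vae) weights."""
    out = [None] * len(tokens)
    for i in text_indices:
        out[i] = linear(text_weight, tokens[i])  # weights held on the layer itself
    for i in vae_indices:
        out[i] = linear(gen_weight, tokens[i])   # weights held on gen_exp
    return out


tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
text_weight = [[1.0, 0.0], [0.0, 1.0]]  # identity for text tokens
gen_weight = [[2.0, 0.0], [0.0, 2.0]]   # doubling for vae tokens
routed = mot_linear(tokens, text_weight, gen_weight,
                    text_indices=[0, 2], vae_indices=[1])
```

Each token passes through exactly one of the two weight sets, and the outputs stay in the original packed order.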
down_proj instance-attribute ¶
down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=False,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
gate_up_proj instance-attribute ¶
gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[intermediate_size, intermediate_size],
bias=False,
gather_output=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
BagelRotaryEmbedding ¶
Bases: Module
Standalone rotary embedding that generates cos/sin from position ids.
Replaces HuggingFace's Qwen2RotaryEmbedding while preserving full rope_scaling support. When config.rope_scaling is set (e.g. linear, dynamic-NTK, YaRN, …), we delegate the inv_freq / attention_scaling computation to HF's ROPE_INIT_FUNCTIONS so that the frequency basis and scaling factor are identical to the original checkpoint.
For Qwen2.5-VL-style multimodal RoPE (rope_scaling.rope_type == "mrope") the inv_freq basis is the standard default-rope one; the difference is that position ids are 3-D (t, h, w) per token and the mrope_section describes how the head dimension is split across axes. This module accepts either 2-D scalar position ids (B, S) or 3-D multimodal position ids (B, 3, S) and dispatches accordingly so the same module works for both BAGEL (1-D rope) and Lance (Qwen2.5-VL mrope). This module has no learnable parameters.
forward ¶
forward(
x: Tensor, position_ids: Tensor
) -> tuple[Tensor, Tensor]
Generate cos/sin embeddings for given position ids.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor (only used for dtype inference). | required |
| position_ids | Tensor | Either 2-D scalar position ids of shape (B, S) or 3-D multimodal position ids of shape (B, 3, S). | required |
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, Tensor] | cos, sin: Rotary embeddings, each of shape (batch_size, seq_len, dim). |
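For the plain 1-D case without rope_scaling, generating cos/sin from position ids can be sketched as below. This assumes the default RoPE frequency basis with a base of 10000 and the Hugging Face layout in which the half-dimension angles are duplicated to fill the head dimension; `rotary_cos_sin` is a hypothetical helper that handles a single sequence and ignores the batch axis, the mrope path, and any attention scaling factor.

```python
import math


def rotary_cos_sin(position_ids, dim, base=10000.0):
    """Return (cos, sin), each a list of shape (seq_len, dim)."""
    # One inverse frequency per pair of channels.
    inv_freq = [base ** (-(2 * i) / dim) for i in range(dim // 2)]
    cos, sin = [], []
    for pos in position_ids:
        angles = [pos * freq for freq in inv_freq]
        angles = angles + angles  # duplicate to cover the full head dimension
        cos.append([math.cos(a) for a in angles])
        sin.append([math.sin(a) for a in angles])
    return cos, sin


cos, sin = rotary_cos_sin([0, 1, 2], dim=4)
```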
BaseNavitOutputWithPast dataclass ¶
Bases: ModelOutput
packed_query_sequence class-attribute instance-attribute ¶
MLPconnector ¶
Bases: Module
fc1 instance-attribute ¶
fc1 = ColumnParallelLinear(
input_dim,
output_dim,
bias=True,
gather_output=False,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
fc2 instance-attribute ¶
fc2 = RowParallelLinear(
output_dim,
output_dim,
bias=True,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
NaiveCache ¶
PackedAttentionMoT ¶
Bases: Module
Packed attention with Mixture-of-Tokens routing for understanding/generation.
Uses MoTQKVParallelLinear and MoTRowParallelLinear for tensor parallelism. Text and vae weights are held within the same MoT layer (text on self, vae on self.gen_exp). Token routing is driven by text_indices / vae_indices.
attn_causal instance-attribute ¶
attn_causal = DiffusionAttention(
num_heads=self.total_num_heads,
head_size=self.head_dim,
softmax_scale=1.0 / self.head_dim**0.5,
causal=True,
num_kv_heads=self.total_num_kv_heads,
)
attn_noncausal instance-attribute ¶
attn_noncausal = DiffusionAttention(
num_heads=self.total_num_heads,
head_size=self.head_dim,
softmax_scale=1.0 / self.head_dim**0.5,
causal=False,
num_kv_heads=self.total_num_kv_heads,
)
k_norm instance-attribute ¶
k_norm = MoTRMSNorm(
self.head_dim, head_norm=True, eps=config.rms_norm_eps
)
o_proj instance-attribute ¶
o_proj = MoTRowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
q_norm instance-attribute ¶
q_norm = MoTRMSNorm(
self.head_dim, head_norm=True, eps=config.rms_norm_eps
)
qkv_proj instance-attribute ¶
qkv_proj = MoTQKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
vae_bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
forward ¶
forward(
packed_query_sequence: Tensor,
query_lens: Tensor,
packed_query_position_embeddings: Tensor,
past_key_values: NaiveCache | None = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
)
PositionEmbedding ¶
Bases: Module
pos_embed instance-attribute ¶
pos_embed = nn.Parameter(
torch.zeros(max_num_patch_per_side**2, hidden_size),
requires_grad=False,
)
Qwen2MoTConfig ¶
Bases: Qwen2Config
Configuration for Qwen2MoT (Mixture of Tokens) model.
This is fundamentally different from Qwen2, hence the distinct name.
Qwen2MoTDecoderLayer ¶
Bases: Module
input_layernorm instance-attribute ¶
input_layernorm = MoTRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
mlp instance-attribute ¶
mlp = BagelMLP(
config.hidden_size,
config.intermediate_size,
config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
mlp_moe_gen instance-attribute ¶
mlp_moe_gen = BagelMLP(
config.hidden_size,
config.intermediate_size,
config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp_moe_gen",
)
post_attention_layernorm instance-attribute ¶
post_attention_layernorm = MoTRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self_attn instance-attribute ¶
self_attn = attn_module(
config,
layer_idx,
parallel_config=parallel_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
forward ¶
forward(
hidden_states: Tensor,
encoder_hidden_states: Tensor | None = None,
packed_query_sequence: Tensor | None = None,
query_lens: Tensor = None,
packed_query_position_embeddings: Tensor = None,
past_key_values: NaiveCache | None = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
) -> BaseNavitOutputWithPast
Qwen2MoTForCausalLM ¶
Bases: Qwen2PreTrainedModel
lm_head instance-attribute ¶
model instance-attribute ¶
model = Qwen2MoTModel(
config,
parallel_config=parallel_config,
quant_config=quant_config,
prefix=f"{prefix}.model",
)
forward ¶
forward(
packed_query_sequence: Tensor | None = None,
query_lens: Tensor | None = None,
packed_query_position_ids: Tensor | None = None,
past_key_values: NaiveCache | None = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
packed_text_ids: Tensor | None = None,
return_embeddings_only: bool = False,
) -> BaseNavitOutputWithPast
load_weights ¶
Load weights for MoT parallel layers.
Stacked parameter remapping (checkpoint name → model parameter):
- q/k/v_proj → qkv_proj (text, shard q/k/v)
- q/k/v_proj_moe_gen → qkv_proj.gen_exp (gen, shard q/k/v)
Direct remapping (no shard dimension):
- o_proj_moe_gen → o_proj.gen_exp
- {norm}_moe_gen.weight → {norm}.gen_weight (all MoTRMSNorm layers)
Text norm weights (input_layernorm.weight, q_norm.weight, etc.) and other names (embed_tokens, lm_head) pass through unchanged.
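The remapping rules listed above can be expressed as a small name-translation function. This is an illustrative sketch only: `remap_name` is a hypothetical helper, the checkpoint names in the usage lines are made-up examples of the documented patterns, and the actual loader also has to copy each shard into the right slice of the fused parameter, which is not shown.

```python
def remap_name(name):
    """Map a checkpoint parameter name to (model parameter name, shard id or None)."""
    for shard in ("q", "k", "v"):
        # Generation-expert projections go to the gen_exp half of the fused layer.
        gen_key = f".{shard}_proj_moe_gen."
        if gen_key in name:
            return name.replace(gen_key, ".qkv_proj.gen_exp."), shard
        # Text projections go to the fused qkv_proj itself.
        text_key = f".{shard}_proj."
        if text_key in name:
            return name.replace(text_key, ".qkv_proj."), shard
    if ".o_proj_moe_gen." in name:
        return name.replace(".o_proj_moe_gen.", ".o_proj.gen_exp."), None
    if name.endswith("_moe_gen.weight"):
        # Norm layers: {norm}_moe_gen.weight -> {norm}.gen_weight
        return name[: -len("_moe_gen.weight")] + ".gen_weight", None
    # Everything else (text norms, embed_tokens, lm_head, ...) passes through.
    return name, None


text_q = remap_name("model.layers.0.self_attn.q_proj.weight")
gen_k = remap_name("model.layers.0.self_attn.k_proj_moe_gen.bias")
gen_o = remap_name("model.layers.0.self_attn.o_proj_moe_gen.weight")
gen_norm = remap_name("model.layers.0.input_layernorm_moe_gen.weight")
unchanged = remap_name("model.embed_tokens.weight")
```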
Qwen2MoTModel ¶
Bases: Qwen2PreTrainedModel
embed_tokens instance-attribute ¶
layers instance-attribute ¶
layers = nn.ModuleList(
[
(
Qwen2MoTDecoderLayer(
config,
layer_idx,
attn_module=PackedAttentionMoT,
parallel_config=parallel_config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
)
)
for layer_idx in (range(config.num_hidden_layers))
]
)
forward ¶
forward(
packed_query_sequence: Tensor | None = None,
query_lens: Tensor | None = None,
packed_query_position_ids: Tensor | None = None,
past_key_values: NaiveCache | None = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
packed_text_ids: Tensor | None = None,
return_embeddings_only: bool = False,
) -> BaseNavitOutputWithPast
TimestepEmbedder ¶
Bases: Module
Embeds scalar timesteps into vector representations.
get_1d_sincos_pos_embed_from_grid ¶
embed_dim: output dimension for each position
pos: a list of positions to be encoded, size (M,)
out: (M, D)
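A pure-Python sketch of a 1-D sine/cosine position embedding with this input and output contract is shown below. It follows the widely used MAE-style formulation (frequencies 1 / 10000^(i / (D/2)), sine half followed by cosine half); that this function uses exactly the same frequency schedule and ordering is an assumption, and `sincos_1d` is a hypothetical name.

```python
import math


def sincos_1d(embed_dim, pos):
    """Encode M positions into an (M, embed_dim) list of sin/cos features."""
    assert embed_dim % 2 == 0, "embed_dim must be even"
    half = embed_dim // 2
    omega = [1.0 / 10000 ** (i / half) for i in range(half)]
    out = []
    for p in pos:
        angles = [p * w for w in omega]
        # First half holds the sines, second half the cosines.
        out.append([math.sin(a) for a in angles] + [math.cos(a) for a in angles])
    return out


emb = sincos_1d(4, [0.0, 1.0])
```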
get_2d_sincos_pos_embed ¶
get_flattened_position_ids_extrapolate ¶
patchify ¶
imgs: (N, 3, H, W) or (3, H, W)
x: (N, L, patch_size**2 * 3) or (L, patch_size**2 * 3)
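The shape contract for the unbatched case can be made concrete with a small sketch on nested lists. `patchify_single` is a hypothetical helper, and the element order inside each patch (row within patch, column within patch, then channel) is an assumption; only the output shape (L, patch_size**2 * 3) is taken from the description above.

```python
def patchify_single(img, p):
    """Split a (3, H, W) nested list into L = (H/p) * (W/p) flattened patches."""
    channels = len(img)
    height, width = len(img[0]), len(img[0][0])
    assert height % p == 0 and width % p == 0, "H and W must be divisible by p"
    patches = []
    for ph in range(height // p):          # patch row
        for pw in range(width // p):       # patch column
            patch = []
            for i in range(p):             # row inside the patch
                for j in range(p):         # column inside the patch
                    for c in range(channels):
                        patch.append(img[c][ph * p + i][pw * p + j])
            patches.append(patch)
    return patches


# Pixel value encodes its location: channel * 100 + row * 10 + column.
img = [[[c * 100 + r * 10 + col for col in range(2)] for r in range(2)]
       for c in range(3)]
per_pixel = patchify_single(img, 1)
whole_image = patchify_single(img, 2)
```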