vllm_omni.diffusion.models.bagel.bagel_transformer
Bagel
Bases: Module
connector instance-attribute
connector = MLPconnector(
vit_hidden_size,
hidden_size,
connector_act,
quant_config=quant_config,
prefix=f"{prefix}.connector",
)
get_flattened_position_ids instance-attribute
get_flattened_position_ids = (
get_flattened_position_ids_extrapolate
)
latent_pos_embed instance-attribute
latent_pos_embed = PositionEmbedding(
max_latent_size, hidden_size
)
vit_max_num_patch_per_side instance-attribute
vit_pos_embed instance-attribute
vit_pos_embed = PositionEmbedding(
vit_max_num_patch_per_side, hidden_size
)
forward
forward(
x_t: Tensor,
timestep: LongTensor,
packed_vae_token_indexes: LongTensor,
packed_vae_position_ids: LongTensor,
packed_text_ids: LongTensor,
packed_text_indexes: LongTensor,
packed_position_ids: LongTensor,
packed_seqlens: IntTensor,
past_key_values: NaiveCache,
cfg_renorm_min: float = 0.0,
cfg_renorm_type: str = "global",
cfg_text_scale: float = 1.0,
cfg_img_scale: float = 1.0,
cfg_branches: dict | None = None,
)
forward_cache_update_text
forward_cache_update_text(
past_key_values: NaiveCache,
packed_text_ids: IntTensor,
packed_text_position_ids: LongTensor,
text_token_lens: LongTensor,
)
forward_cache_update_vae
forward_cache_update_vae(
vae_model,
past_key_values: NaiveCache,
padded_images: Tensor,
patchified_vae_latent_shapes: list,
packed_vae_position_ids: LongTensor,
packed_timesteps: Tensor,
packed_vae_token_indexes: LongTensor,
packed_text_ids: LongTensor,
packed_text_indexes: LongTensor,
packed_position_ids: LongTensor,
packed_seqlens: IntTensor,
)
forward_cache_update_vit
forward_cache_update_vit(
past_key_values: NaiveCache,
packed_text_ids: LongTensor,
packed_text_indexes: LongTensor,
packed_vit_tokens: Tensor,
packed_vit_token_indexes: LongTensor,
packed_vit_position_ids: LongTensor,
vit_token_seqlens: IntTensor,
packed_position_ids: LongTensor,
packed_seqlens: IntTensor,
)
forward_single_branch
forward_single_branch(
x_t: Tensor,
timestep: LongTensor,
packed_vae_token_indexes: LongTensor,
packed_vae_position_ids: LongTensor,
packed_text_ids: LongTensor,
packed_text_indexes: LongTensor,
packed_position_ids: LongTensor,
packed_seqlens: IntTensor,
past_key_values: NaiveCache,
) -> Tensor
Run a single-branch forward pass (no CFG batching).
Used by CFG parallel mode, where each rank computes one branch. Returns the velocity v_t for the given branch. Supports Ulysses and Ring sequence parallelism (SP) when parallel_config.sequence_parallel_size > 1.
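In CFG parallel mode the per-branch velocities still have to be merged using the cfg_text_scale, cfg_img_scale and cfg_renorm_min arguments of forward. A minimal sketch of one way to do that, assuming the formula of the original BAGEL release (text guidance first, then image guidance, then a "global" norm rescale clamped to [cfg_renorm_min, 1]). combine_cfg is a hypothetical helper, flat lists stand in for tensors, and other cfg_renorm_type values are not modeled.

```python
import math


def _norm(v):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(x * x for x in v))


def combine_cfg(v_cond, v_text_uncond, v_img_uncond,
                cfg_text_scale=1.0, cfg_img_scale=1.0, cfg_renorm_min=0.0):
    """Merge three branch velocities (hypothetical helper, 'global' renorm only).

    v_cond:        velocity with full conditioning
    v_text_uncond: velocity with the text condition dropped
    v_img_uncond:  velocity with the image condition dropped
    """
    if cfg_text_scale <= 1.0:
        # No text guidance requested: use the conditional branch unchanged.
        return list(v_cond)
    # Text guidance: extrapolate away from the text-free branch.
    v_text = [u + cfg_text_scale * (c - u) for c, u in zip(v_cond, v_text_uncond)]
    # Optional image guidance on top of the text-guided velocity.
    if cfg_img_scale > 1.0:
        v = [u + cfg_img_scale * (t - u) for t, u in zip(v_text, v_img_uncond)]
    else:
        v = v_text
    # Global renorm: shrink the guided velocity back toward the norm of the
    # conditional one, but never by a factor below cfg_renorm_min.
    scale = _norm(v_cond) / (_norm(v) + 1e-8)
    scale = max(cfg_renorm_min, min(scale, 1.0))
    return [x * scale for x in v]
```

With cfg_renorm_min=0.0 the renorm fully cancels the magnitude growth caused by large guidance scales while keeping the guided direction; raising cfg_renorm_min lets some of that growth through.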
generate_image
generate_image(
packed_text_ids: LongTensor,
packed_text_indexes: LongTensor,
packed_init_noises: Tensor,
packed_vae_position_ids: LongTensor,
packed_vae_token_indexes: LongTensor,
packed_seqlens: IntTensor,
packed_position_ids: LongTensor,
past_key_values: NaiveCache,
num_timesteps: int = 24,
timestep_shift: float = 1.0,
cfg_renorm_min: float = 0.0,
cfg_renorm_type: str = "global",
cfg_interval: tuple[float, float] = [0, 1],
cfg_text_scale: float = 1.0,
cfg_text_packed_position_ids: LongTensor | None = None,
cfg_text_past_key_values: NaiveCache | None = None,
cfg_img_scale: float = 1.0,
cfg_img_packed_position_ids: LongTensor | None = None,
cfg_img_past_key_values: NaiveCache | None = None,
return_trajectory_latents: bool = False,
scheduler: object | None = None,
scheduler_kwargs: dict | None = None,
frame_condition_token_indexes: LongTensor | None = None,
)
generate_text
generate_text(
past_key_values: NaiveCache,
packed_start_tokens: LongTensor,
packed_query_position_ids: LongTensor,
max_length: int,
do_sample: bool = False,
temperature: float = 1.0,
end_token_id: int | None = None,
)
Autoregressive text generation (ported from the original BAGEL).
Decodes tokens one at a time, appending to past_key_values until max_length is reached or end_token_id is generated.
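The decoding loop described above can be sketched as follows. This is a minimal illustration, not the method itself: decode and step_fn are hypothetical, step_fn stands in for a forward pass that returns next-token logits (and, in the real model, appends the token's keys and values to past_key_values), and it is an assumption that max_length counts the start token.

```python
import math
import random


def decode(step_fn, start_token, max_length, do_sample=False,
           temperature=1.0, end_token_id=None, rng=None):
    """Token-by-token decoding loop (hypothetical helper)."""
    rng = rng or random.Random(0)
    tokens = [start_token]
    token = start_token
    while len(tokens) < max_length:
        logits = step_fn(token)
        if do_sample:
            # Temperature-scaled softmax sampling (max subtracted for stability).
            scaled = [l / temperature for l in logits]
            m = max(scaled)
            weights = [math.exp(s - m) for s in scaled]
            token = rng.choices(range(len(logits)), weights=weights)[0]
        else:
            # Greedy decoding: take the highest-scoring token.
            token = max(range(len(logits)), key=logits.__getitem__)
        tokens.append(token)
        if end_token_id is not None and token == end_token_id:
            break
    return tokens
```

Usage: with a toy step_fn that always prefers "previous token + 1", greedy decoding from 0 with max_length=5 yields [0, 1, 2, 3, 4], and passing end_token_id=2 stops at [0, 1, 2].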
prepare_start_tokens
Prepare start tokens for autoregressive text generation.
Ported from the original BAGEL Bagel.prepare_start_tokens.
prepare_vae_images
prepare_vit_images
BagelMLP
Bases: Module
down_proj instance-attribute
down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
input_is_parallel=True,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
gate_up_proj instance-attribute
gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size, intermediate_size],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
BagelRotaryEmbedding
Bases: Module
Standalone rotary embedding that generates cos/sin from position ids.
Replaces HuggingFace's Qwen2RotaryEmbedding while preserving full rope_scaling support. When config.rope_scaling is set (e.g. linear, dynamic-NTK, YaRN, …), we delegate the inv_freq / attention_scaling computation to HF's ROPE_INIT_FUNCTIONS so that the frequency basis and scaling factor are identical to the original checkpoint.
For Qwen2.5-VL-style multimodal RoPE (rope_scaling.rope_type == "mrope") the inv_freq basis is the standard default-rope one; the difference is that position ids are 3-D (t, h, w) per token and the mrope_section describes how the head dimension is split across axes. This module accepts either 2-D scalar position ids (B, S) or 3-D multimodal position ids (B, 3, S) and dispatches accordingly so the same module works for both BAGEL (1-D rope) and Lance (Qwen2.5-VL mrope). This module has no learnable parameters.
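The two position-id layouts described above can share one frequency basis. A minimal sketch under stated assumptions: base 10000 (rope_theta), no rope_scaling, a single sequence instead of a batch, and the Qwen2.5-VL convention of assigning consecutive frequency chunks to the t, h and w axes according to mrope_section. default_inv_freq and rope_cos_sin are hypothetical helpers, not this module's methods.

```python
import math


def default_inv_freq(head_dim, base=10000.0):
    """inv_freq[j] = base ** (-2j / head_dim) for j in [0, head_dim // 2)."""
    return [base ** (-(2 * j) / head_dim) for j in range(head_dim // 2)]


def rope_cos_sin(position_ids, inv_freq, mrope_section=None):
    """cos/sin rows for one sequence.

    position_ids: list[int] for 1-D rope, or [t_ids, h_ids, w_ids] for mrope.
    Returns (cos, sin), each seq_len rows of length 2 * len(inv_freq).
    """
    if mrope_section is None:
        axis_of = [0] * len(inv_freq)      # every frequency reads the same ids
        per_axis = [position_ids]
    else:
        if sum(mrope_section) != len(inv_freq):
            raise ValueError("mrope_section must sum to head_dim // 2")
        # Consecutive frequency chunks are driven by the t, h, w axes in turn.
        axis_of = [a for a, n in enumerate(mrope_section) for _ in range(n)]
        per_axis = position_ids
    cos, sin = [], []
    for s in range(len(per_axis[0])):
        freqs = [per_axis[axis_of[j]][s] * f for j, f in enumerate(inv_freq)]
        emb = freqs + freqs                # duplicate to cover the full head dim
        cos.append([math.cos(x) for x in emb])
        sin.append([math.sin(x) for x in emb])
    return cos, sin
```

Passing identical t, h and w ids reduces the mrope path to ordinary 1-D rope, which is a convenient sanity check for this kind of dispatch.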
forward
forward(
x: Tensor, position_ids: Tensor
) -> tuple[Tensor, Tensor]
Generate cos/sin embeddings for given position ids.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor (only used for dtype inference). | required |
| position_ids | Tensor | Either 2-D scalar position ids (B, S) or 3-D multimodal position ids (B, 3, S). | required |
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, Tensor] | cos, sin: Rotary embeddings, each of shape (batch_size, seq_len, dim). |
BaseNavitOutputWithPast dataclass
Bases: ModelOutput
packed_query_sequence class-attribute instance-attribute
MLPconnector
Bases: Module
fc1 instance-attribute
fc1 = ColumnParallelLinear(
input_dim,
output_dim,
bias=True,
gather_output=False,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
fc2 instance-attribute
fc2 = RowParallelLinear(
output_dim,
output_dim,
bias=True,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
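fc1 is a ColumnParallelLinear with gather_output=False and fc2 a RowParallelLinear with input_is_parallel=True, so the hidden activation can stay sharded between them. The sketch below checks that numerically on simulated ranks using plain lists. relu is a stand-in for connector_act (whose actual value is not shown here), the helper names are hypothetical, and the final all-reduce is modeled as a plain sum with the fc2 bias added once.

```python
def matvec(w, x):
    """y = W x, with W given as a list of rows."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def relu(v):
    return [max(0.0, a) for a in v]


def connector_reference(x, w1, b1, w2, b2, act=relu):
    """Unsharded fc1 -> act -> fc2."""
    h = [a + b for a, b in zip(matvec(w1, x), b1)]
    return [a + b for a, b in zip(matvec(w2, act(h)), b2)]


def connector_tensor_parallel(x, w1, b1, w2, b2, tp_size, act=relu):
    """Same computation with fc1 split by output rows and fc2 by input columns."""
    hidden = len(w1)
    assert hidden % tp_size == 0
    shard = hidden // tp_size
    partial_sums = []
    for rank in range(tp_size):
        lo, hi = rank * shard, (rank + 1) * shard
        # Column-parallel fc1: this rank owns output rows lo:hi (no gather).
        h = [a + b for a, b in zip(matvec(w1[lo:hi], x), b1[lo:hi])]
        # Row-parallel fc2: this rank owns the matching input columns.
        w2_cols = [row[lo:hi] for row in w2]
        partial_sums.append(matvec(w2_cols, act(h)))
    # All-reduce (sum) across ranks, then add the fc2 bias once.
    return [sum(parts) + b for parts, b in zip(zip(*partial_sums), b2)]
```

This only works because the activation is elementwise: each rank can apply it to its own slice of the hidden vector without seeing the other slices.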
NaiveCache
PackedAttentionMoT
Bases: Module
Packed attention with Mixture-of-Tokens routing for understanding/generation.
Uses vLLM's QKVParallelLinear and RowParallelLinear for tensor parallelism support, following the same pattern as vLLM's Qwen2Attention.
The q/k/v projections are stacked into two QKVParallelLinear layers:
- qkv_proj: stacks q_proj + k_proj + v_proj (understanding tokens and generation-mode text tokens)
- qkv_proj_moe_gen: stacks q_proj_moe_gen + k_proj_moe_gen + v_proj_moe_gen (generation-mode VAE tokens)
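The token routing described above can be sketched as a scatter by index. A minimal illustration, assuming that in generation mode the positions in packed_text_indexes use qkv_proj and those in packed_vae_token_indexes use qkv_proj_moe_gen, and that together they cover every packed token once. route_qkv is hypothetical, the callables stand in for the two projection layers, and the real layer works on whole tensors with index operations rather than per-token calls.

```python
def route_qkv(hidden_states, text_indexes, vae_indexes, und_proj, gen_proj):
    """Project each packed token with the branch its index belongs to.

    hidden_states: list of per-token vectors in packed order.
    und_proj / gen_proj: stand-ins for qkv_proj / qkv_proj_moe_gen.
    """
    if set(text_indexes) & set(vae_indexes):
        raise ValueError("a token cannot be routed to both branches")
    out = [None] * len(hidden_states)
    for i in text_indexes:
        out[i] = und_proj(hidden_states[i])
    for i in vae_indexes:
        out[i] = gen_proj(hidden_states[i])
    if any(o is None for o in out):
        raise ValueError("every packed token must be routed to a branch")
    return out  # packed order is preserved for the shared attention call
```

Keeping the packed order matters because both token types then attend to each other in a single attention call; only the projection weights differ.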
attn_causal instance-attribute
attn_causal = Attention(
num_heads=total_num_heads,
head_size=head_dim,
softmax_scale=1.0 / head_dim**0.5,
causal=True,
num_kv_heads=total_num_kv_heads,
)
attn_noncausal instance-attribute
attn_noncausal = Attention(
num_heads=total_num_heads,
head_size=head_dim,
softmax_scale=1.0 / head_dim**0.5,
causal=False,
num_kv_heads=total_num_kv_heads,
)
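attn_causal and attn_noncausal share every setting except the causal flag, including softmax_scale = 1 / sqrt(head_dim). Presumably the is_causal argument of forward picks between them. The sketch shows what the flag changes in single-head scaled dot-product attention; it assumes self-attention with equal query and key lengths and no KV cache offset, and the attention helper is hypothetical.

```python
import math


def attention(q, k, v, causal):
    """Single-head scaled dot-product attention over lists of vectors."""
    head_dim = len(q[0])
    scale = 1.0 / head_dim ** 0.5          # same softmax_scale as above
    out = []
    for i, qi in enumerate(q):
        # A causal query attends only to keys at positions <= i.
        keys = range(i + 1) if causal else range(len(k))
        scores = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in keys]
        m = max(scores)                    # subtract max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j][d] for w, j in zip(weights, keys)) / total
                    for d in range(len(v[0]))])
    return out
```

With identical keys, a causal first query can only see the first value, while a non-causal one averages over all values.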
o_proj instance-attribute
o_proj = RowParallelLinear(
total_num_heads * head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
o_proj_moe_gen instance-attribute
o_proj_moe_gen = RowParallelLinear(
total_num_heads * head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj_moe_gen",
)
qkv_proj instance-attribute
qkv_proj = QKVParallelLinear(
hidden_size,
head_dim,
total_num_heads,
total_num_kv_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
qkv_proj_moe_gen instance-attribute
qkv_proj_moe_gen = QKVParallelLinear(
hidden_size,
head_dim,
total_num_heads,
total_num_kv_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj_moe_gen",
)
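Both fused projections take (hidden_size, head_dim, total_num_heads, total_num_kv_heads), so each tensor-parallel rank holds a slice of the query heads and of the (grouped) KV heads. The sketch computes per-rank widths, assuming vLLM's QKVParallelLinear rule that when tp_size >= total_num_kv_heads each rank keeps one replicated KV head. qkv_partition is a hypothetical helper, and the numbers in the usage note are illustrative rather than BAGEL's confirmed config.

```python
def qkv_partition(hidden_size, head_dim, total_num_heads, total_num_kv_heads, tp_size):
    """Per-rank q/k/v output widths of a fused GQA projection (sketch)."""
    assert total_num_heads % tp_size == 0
    num_heads = total_num_heads // tp_size
    if tp_size >= total_num_kv_heads:
        # More ranks than KV heads: every rank holds one replicated KV head.
        assert tp_size % total_num_kv_heads == 0
        num_kv_heads = 1
    else:
        assert total_num_kv_heads % tp_size == 0
        num_kv_heads = total_num_kv_heads // tp_size
    q_size = num_heads * head_dim
    kv_size = num_kv_heads * head_dim
    return {
        "in": hidden_size,                  # the input dimension is not sharded
        "q": q_size,
        "k": kv_size,
        "v": kv_size,
        "fused_out": q_size + 2 * kv_size,  # split back into q, k, v after the matmul
    }
```

Usage: for a hypothetical 28-head, 4-KV-head, head_dim=128 layer, tp_size=2 gives q=1792, k=v=256 and a fused width of 2304 per rank.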
forward
forward(
packed_query_sequence: Tensor,
query_lens: Tensor,
packed_query_position_embeddings: Tensor,
past_key_values: NaiveCache | None = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
)
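Qwen2MoTConfig
Bases: Qwen2Config
Configuration for the Qwen2MoT (Mixture-of-Tokens) model.
The architecture differs from plain Qwen2 (each decoder layer adds separate *_moe_gen attention, MLP, and layer-norm weights for generation tokens), hence the distinct name.
Qwen2MoTDecoderLayer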
Bases: Module
input_layernorm_moe_gen instance-attribute
mlp instance-attribute
mlp = BagelMLP(
hidden_size,
intermediate_size,
hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
mlp_moe_gen instance-attribute
mlp_moe_gen = BagelMLP(
hidden_size,
intermediate_size,
hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp_moe_gen",
)
post_attention_layernorm instance-attribute
post_attention_layernorm_moe_gen instance-attribute
self_attn instance-attribute
self_attn = attn_module(
config,
layer_idx,
parallel_config=parallel_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
forward
forward(
hidden_states: Tensor,
encoder_hidden_states: Tensor | None = None,
packed_query_sequence: Tensor | None = None,
query_lens: Tensor = None,
packed_query_position_embeddings: Tensor = None,
past_key_values: NaiveCache | None = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
) -> BaseNavitOutputWithPast
Qwen2MoTForCausalLM
Bases: Qwen2PreTrainedModel
model instance-attribute
model = Qwen2MoTModel(
config,
parallel_config=parallel_config,
quant_config=quant_config,
prefix=f"{prefix}.model",
)
forward
forward(
packed_query_sequence: Tensor | None = None,
query_lens: Tensor | None = None,
packed_query_position_ids: Tensor | None = None,
past_key_values: NaiveCache | None = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
packed_text_ids: Tensor | None = None,
return_embeddings_only: bool = False,
) -> BaseNavitOutputWithPast
load_weights
Load weights for vLLM parallel layers.
Handles stacked parameter remapping for QKVParallelLinear:
- q_proj, k_proj, v_proj -> qkv_proj (shard ids: q, k, v)
- q_proj_moe_gen, k_proj_moe_gen, v_proj_moe_gen -> qkv_proj_moe_gen (shard ids: q, k, v)
Other parallel layers (gate_proj, up_proj, down_proj, embed_tokens, etc.) keep their HF checkpoint names and use their weight_loader for TP sharding.
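The remapping above is usually expressed as a table of (param_name, checkpoint_name, shard_id) triples, a pattern common in vLLM model files. A minimal sketch of the name-translation step only; STACKED_PARAMS_MAPPING and remap_checkpoint_name are hypothetical, and the example tensor names are illustrative. The real method would then pass the tensor and shard id to the target parameter's weight_loader.

```python
# (param_name, checkpoint_name, shard_id)
STACKED_PARAMS_MAPPING = [
    ("qkv_proj_moe_gen", "q_proj_moe_gen", "q"),
    ("qkv_proj_moe_gen", "k_proj_moe_gen", "k"),
    ("qkv_proj_moe_gen", "v_proj_moe_gen", "v"),
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def remap_checkpoint_name(name):
    """Map an HF checkpoint tensor name to (parameter name, shard id or None)."""
    for param_name, ckpt_name, shard_id in STACKED_PARAMS_MAPPING:
        # Matching ".q_proj." (with dots) keeps it from also hitting ".q_proj_moe_gen.".
        if f".{ckpt_name}." in name:
            return name.replace(f".{ckpt_name}.", f".{param_name}."), shard_id
    # Not stacked: keep the checkpoint name; the param's weight_loader shards it.
    return name, None
```

Matching on dotted module names makes the result independent of table order, which avoids a subtle bug when one module name is a prefix of another.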
Qwen2MoTModel
Bases: Qwen2PreTrainedModel
layers instance-attribute
layers = ModuleList(
[
(
Qwen2MoTDecoderLayer(
config,
layer_idx,
attn_module=PackedAttentionMoT,
parallel_config=parallel_config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
)
)
for layer_idx in (range(num_hidden_layers))
]
)
forward
forward(
packed_query_sequence: Tensor | None = None,
query_lens: Tensor | None = None,
packed_query_position_ids: Tensor | None = None,
past_key_values: NaiveCache | None = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
packed_text_ids: Tensor | None = None,
return_embeddings_only: bool = False,
) -> BaseNavitOutputWithPast
TimestepEmbedder
Bases: Module
Embeds scalar timesteps into vector representations.
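A common way to embed a scalar timestep is a bank of sinusoids at geometrically spaced frequencies, which is then fed to a small MLP. The sketch shows only the sinusoidal part, assuming the DiT-style convention (max_period 10000, cosine half first, zero padding for odd sizes); whether this class follows DiT exactly is an assumption, and timestep_embedding is a hypothetical helper.

```python
import math


def timestep_embedding(t, dim, max_period=10000.0):
    """Sinusoidal features for a scalar timestep t (DiT-style sketch)."""
    half = dim // 2
    # Frequencies decay geometrically from 1 down to about 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:
        emb.append(0.0)  # pad odd dimensions with a zero
    return emb
```

The geometric spacing lets nearby timesteps differ mainly in the high-frequency channels while distant ones also differ in the low-frequency channels.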
get_1d_sincos_pos_embed_from_grid
embed_dim: output dimension for each position
pos: a list of positions to be encoded, size (M,)
out: (M, D)
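The (M, D) contract above matches the widely used MAE-style sin-cos embedding. A minimal sketch, assuming that convention (base 10000, D/2 frequencies, sine half followed by cosine half); sincos_1d is a hypothetical name and nested lists stand in for arrays.

```python
import math


def sincos_1d(embed_dim, positions):
    """Return an (M, D) list of lists for M positions (MAE-style sketch)."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # omega_i = 1 / 10000 ** (i / half) for i = 0 .. half - 1
    omega = [1.0 / 10000 ** (i / half) for i in range(half)]
    out = []
    for p in positions:
        angles = [p * w for w in omega]
        out.append([math.sin(a) for a in angles] + [math.cos(a) for a in angles])
    return out
```

A 2-D variant typically applies this twice, once to the row coordinates and once to the column coordinates, each with embed_dim // 2 channels, and concatenates the results.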
get_2d_sincos_pos_embed
get_flattened_position_ids_extrapolate
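Bagel aliases get_flattened_position_ids to this function, and its latent_pos_embed and vit_pos_embed are PositionEmbedding tables sized by max_latent_size and vit_max_num_patch_per_side. A plausible reading is that each patch gets a row-major id on a fixed grid that is max_num_patches_per_side wide, so ids index a table of max_num_patches_per_side**2 rows regardless of image size. The sketch below encodes that reading; the function name and argument names are assumptions, not the module's confirmed signature.

```python
def flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Row-major patch ids on a fixed-width grid (hypothetical sketch)."""
    num_h = img_h // patch_size
    num_w = img_w // patch_size
    # Row stride is the maximum grid width, not num_w, so a patch at (r, c)
    # always maps to the same table row whatever the image aspect ratio.
    return [r * max_num_patches_per_side + c for r in range(num_h) for c in range(num_w)]
```

Usage: a 32 x 48 image with 16-pixel patches on a 4-wide grid yields ids [0, 1, 2, 4, 5, 6]; the gap at 3 is the unused fourth column of the first row.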
patchify
imgs: (N, 3, H, W) or (3, H, W)
x: (N, L, patch_size**2 * 3) or (L, patch_size**2 * 3)
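The shapes above say that each of the L patches becomes a vector of patch_size**2 * 3 values. A sketch for a single (C, H, W) image, assuming the MAE element order (row within patch, column within patch, channel), which the docstring does not state. patchify_single is a hypothetical name, nested lists stand in for tensors, and any channel count is accepted.

```python
def patchify_single(img, p):
    """(C, H, W) nested lists -> L patches of length p * p * C, L = (H // p) * (W // p)."""
    c, h, w = len(img), len(img[0]), len(img[0][0])
    assert h % p == 0 and w % p == 0
    patches = []
    for ph in range(h // p):              # patch rows, top to bottom
        for pw in range(w // p):          # patch columns, left to right
            vec = []
            for i in range(p):            # row inside the patch
                for j in range(p):        # column inside the patch
                    for ch in range(c):   # channel varies fastest
                        vec.append(img[ch][ph * p + i][pw * p + j])
            patches.append(vec)
    return patches
```

In tensor code the same result is typically obtained with one reshape, a permute to (h, w, p, p, C), and a final reshape to (L, p * p * C).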