
vllm_omni.diffusion.models.ming_flash_omni.condition_encoder

Ming-flash-omni-2.0 condition encoder for image generation.

Pipeline (runs inside the imagegen stage):

thinker hidden states at query-token positions     [B, N, 4096]
                     │
                     ▼ proj_in (Linear, bias=True)
                                                   [B, N, 1536]
                     │
                     ▼ Qwen2ForCausalLM connector (is_causal=False)
                       loaded from <checkpoint>/connector/
                                                   [B, N, 1536]
                     │
                     ▼ proj_out (Linear, bias=True)
                                                   [B, N, 2560]
                     │
                     ▼ F.normalize(dim=-1) × 1000  (text_encoder_norm)
                                                   [B, N, 2560]
                     │
                     ▼
        cap_feats consumed by ZImageTransformer2DModel
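The pipeline above can be sketched at the shape level in PyTorch. This is a hypothetical stand-in: ConditionPipelineSketch is not part of vllm_omni. The real connector is a pretrained Qwen2 base model loaded from connector/, and this sketch replaces it with a single randomly initialised nn.TransformerEncoderLayer run without a causal mask. The widths 4096, 1536 and 2560 and the ×1000 scale come from the diagram.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConditionPipelineSketch(nn.Module):
    """Hypothetical shape-level sketch of thinker states -> cap_feats.

    Assumption: the Qwen2 connector is replaced by one bidirectional
    nn.TransformerEncoderLayer; it is NOT the real connector.
    """

    def __init__(self, thinker_dim=4096, connector_dim=1536, dit_dim=2560,
                 scale=1000.0):
        super().__init__()
        self.proj_in = nn.Linear(thinker_dim, connector_dim, bias=True)
        # Stand-in for the Qwen2 base model; no causal mask is ever passed.
        self.connector = nn.TransformerEncoderLayer(
            d_model=connector_dim, nhead=8, dim_feedforward=256,
            dropout=0.0, batch_first=True)
        self.proj_out = nn.Linear(connector_dim, dit_dim, bias=True)
        self.scale = scale

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        x = self.proj_in(h)     # [B, N, connector_dim]
        x = self.connector(x)   # bidirectional: every token sees every token
        x = self.proj_out(x)    # [B, N, dit_dim]
        # Unit L2 norm per token, then scale: each token's norm becomes `scale`.
        return F.normalize(x, dim=-1) * self.scale


enc = ConditionPipelineSketch().eval()
with torch.no_grad():
    cap_feats = enc(torch.randn(2, 8, 4096))
print(cap_feats.shape)  # torch.Size([2, 8, 2560])
```

Because of the final normalize-and-scale step, every condition token reaches the DiT with the same magnitude regardless of what the connector outputs.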

logger module-attribute

logger = getLogger(__name__)

MingConditionEncoder

Bases: Module

Wraps a Qwen2 connector plus its projection and norm layers, producing condition embeddings for the DiT.

The connector is a Qwen2ForCausalLM loaded from the connector/ subfolder of the Ming checkpoint. We run its base model in a non-causal (bidirectional) mode, because the connector is used as an encoder over the pre-baked query-token hidden states, not as an autoregressive decoder.
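Running the connector non-causally changes what each output token can depend on. The toy below illustrates this with PyTorch's scaled_dot_product_attention on random tensors, not with the actual Qwen2 attention code: under a causal mask the first token cannot see the last one, so perturbing the last token leaves the first output unchanged, while in bidirectional mode it does change.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, D = 1, 2, 4, 8  # toy sizes, not the connector's real dimensions
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))


def first_token_output(v_in: torch.Tensor, causal: bool) -> torch.Tensor:
    # Attention output for query position 0 only.
    return F.scaled_dot_product_attention(q, k, v_in, is_causal=causal)[:, :, 0]


# Perturb only the LAST token's value vector.
v_perturbed = v.clone()
v_perturbed[:, :, -1] += 1.0

causal_delta = (first_token_output(v, True)
                - first_token_output(v_perturbed, True)).abs().max()
bidir_delta = (first_token_output(v, False)
               - first_token_output(v_perturbed, False)).abs().max()
print(causal_delta.item() < 1e-6, bidir_delta.item() > 0.0)  # True True
```

For an encoder over a fixed set of query-token states, the bidirectional form lets early query tokens use information from later ones, which a causal decoder would forbid.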

Parameters:

image_gen_config (MingImageGenConfig, required): MingImageGenConfig from MingFlashOmniConfig.

thinker_hidden_size (int, default 4096): Hidden size of the thinker (BailingMoeV2) model. Used to build a proj_in layer when the connector embedding dim differs from it. For the released checkpoint this is 4096.

device (device | str | None, default None): Placement for the module.

dtype (dtype | None, default None): Parameter dtype (typically bfloat16 or float16).
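The rule "build a proj_in layer when the connector embedding dim differs" can be sketched as a small factory. build_proj_in is a hypothetical helper, not the module's actual code. It falls back to nn.Identity, which matches the documented proj_in default, and uses bias=True as in the pipeline diagram.

```python
import torch.nn as nn


def build_proj_in(thinker_hidden_size: int, connector_hidden_size: int) -> nn.Module:
    """Hypothetical helper: map thinker states to the connector width."""
    if thinker_hidden_size == connector_hidden_size:
        # Widths already match: no projection needed.
        return nn.Identity()
    return nn.Linear(thinker_hidden_size, connector_hidden_size, bias=True)


proj = build_proj_in(4096, 1536)
print(proj)                       # Linear(in_features=4096, out_features=1536, bias=True)
print(build_proj_in(1536, 1536))  # Identity()
```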

config instance-attribute

config = image_gen_config

connector instance-attribute

connector: Module | None = None

connector_hidden_size instance-attribute

connector_hidden_size: int | None = None

norm instance-attribute

norm: Module = Identity()

proj_in instance-attribute

proj_in: Module = Identity()

proj_out instance-attribute

proj_out: Module = Identity()

thinker_hidden_size instance-attribute

thinker_hidden_size = thinker_hidden_size

extra_repr

extra_repr() -> str

forward

forward(
    thinker_hidden_states: Tensor,
    attention_mask: Tensor | None = None,
) -> Tensor

Encode thinker hidden states into DiT condition embeddings.

Parameters:

thinker_hidden_states (Tensor, required): [B, N, thinker_hidden_size] hidden states, sliced at the learnable query-token positions by the stage input processor before being passed here.

attention_mask (Tensor | None, default None): Optional [B, N] mask. Defaults to all-ones.

Returns:

Tensor: [B, N, diffusion_c_input_dim] condition tensor ready for the ZImage transformer's cap_feats input.
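A minimal sketch of the attention_mask default, assuming the [B, N] mask is used as a key-padding mask inside bidirectional attention (the page only states the shape and the all-ones default). default_mask and to_sdpa_mask are hypothetical helpers; no causal triangle is added because the connector runs non-causally.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def default_mask(hidden: torch.Tensor,
                 attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
    # Documented default: an all-ones [B, N] mask (every query token is valid).
    if attention_mask is None:
        attention_mask = torch.ones(hidden.shape[:2], dtype=torch.long,
                                    device=hidden.device)
    return attention_mask


def to_sdpa_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # [B, N] -> boolean [B, 1, 1, N]; True means "this key may be attended to".
    return attention_mask.bool()[:, None, None, :]


h = torch.randn(2, 4, 16)
mask = default_mask(h, None)
print(mask.shape, int(mask.sum()))  # torch.Size([2, 4]) 8
mask[1, -1] = 0                     # pretend sample 1 has a padded last token
q = k = v = h[:, None]              # add a head dim: [B, 1, N, D]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=to_sdpa_mask(mask))
```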

load_from_checkpoint

load_from_checkpoint(model_path: str | Path) -> None

Load the Qwen2 connector + optional projection/norm weights.

This uses HF transformers directly rather than vLLM's weight loader: the connector is small (~1.5B params) and runs only once per request as an encoder, so vLLM's distributed loading machinery would be overkill.
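Loading the connector with HF transformers might look like the sketch below. It assumes connector/ is a standard Hugging Face Qwen2ForCausalLM directory (config.json plus weights) and that `transformers` is installed. connector_dir and load_connector_base are hypothetical helpers. The sketch does not show how the module switches the base model to bidirectional attention, and it does not load the projection or norm weights.

```python
from pathlib import Path
from typing import Optional, Union

import torch


def connector_dir(model_path: Union[str, Path]) -> Path:
    # The connector lives in the connector/ subfolder of the Ming checkpoint.
    return Path(model_path) / "connector"


def load_connector_base(model_path: Union[str, Path],
                        dtype: Optional[torch.dtype] = None) -> torch.nn.Module:
    """Hypothetical sketch: load the Qwen2 connector and keep its base model."""
    # Imported lazily so path handling works without transformers installed.
    from transformers import Qwen2ForCausalLM

    lm = Qwen2ForCausalLM.from_pretrained(connector_dir(model_path),
                                          torch_dtype=dtype)
    # Only the decoder stack (no lm_head) is needed to encode query tokens.
    return lm.model.eval()
```

Usage would be along the lines of `base = load_connector_base("/path/to/ming", torch.bfloat16)`, with the path being a placeholder.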

zero_negative

zero_negative(cap_feats: Tensor) -> Tensor

Return a zero tensor shaped like cap_feats for CFG negatives.
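A minimal sketch of zero_negative and of how an all-zero negative typically feeds classifier-free guidance. torch.zeros_like keeps shape, dtype and device. cfg_combine is a hypothetical helper showing the standard CFG formula; this page does not say how the ZImage pipeline combines the two predictions.

```python
import torch


def zero_negative(cap_feats: torch.Tensor) -> torch.Tensor:
    # Same shape, dtype and device as the positive condition, all zeros.
    return torch.zeros_like(cap_feats)


def cfg_combine(pred_cond: torch.Tensor, pred_uncond: torch.Tensor,
                guidance_scale: float) -> torch.Tensor:
    # Standard CFG: push the prediction away from the unconditional one.
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)


cap_feats = torch.randn(1, 8, 2560, dtype=torch.bfloat16)
neg = zero_negative(cap_feats)
print(neg.shape, neg.dtype, neg.abs().sum().item())  # torch.Size([1, 8, 2560]) torch.bfloat16 0.0
```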