
vllm_omni.diffusion.models.ming_flash_omni.byte5_encoder

ByT5 glyph/text encoder for Ming-flash-omni-2.0 image generation.

Ported from Ming's load_byte5_and_byte5_tokenizer, T5EncoderBlockByT5Mapper, and the byte5 branch of get_condition_embeds_for_image_gen. The released checkpoint's byte5 weights were trained with font/color special tokens, so the tokenizer vocabulary is extended the same way before loading; otherwise byte5_model.pt would fail to load with a shape mismatch at the embedding layer.
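The shape mismatch is easiest to see with a toy vocabulary. The sketch below is stdlib-only and hypothetical: ToyByteTokenizer, check_embedding_rows, and the particular marker tokens are illustrative assumptions, not Ming's or Hugging Face's API. It assumes a ByT5-style layout (three reserved ids, then 256 byte ids) and, in this toy, appends added tokens after the byte range. A checkpoint whose embedding table has rows for the markers only lines up with a tokenizer extended the same way.

```python
import re

NUM_RESERVED = 3  # assumed ByT5-style layout: pad=0, eos=1, unk=2
EOS_ID = 1


class ToyByteTokenizer:
    """Hypothetical byte-level tokenizer with optional added marker tokens."""

    def __init__(self, added_tokens=()):
        self.byte_vocab_size = NUM_RESERVED + 256
        # Each added token gets one id after the byte range, in the given order.
        self.added = {tok: self.byte_vocab_size + i for i, tok in enumerate(added_tokens)}
        # Longest tokens first so "<cn-font-1>" cannot shadow "<cn-font-10>".
        ordered = sorted(self.added, key=len, reverse=True)
        self._pattern = re.compile("|".join(map(re.escape, ordered))) if ordered else None

    def __len__(self):
        return self.byte_vocab_size + len(self.added)

    def encode(self, text):
        ids = []
        pos = 0
        matches = self._pattern.finditer(text) if self._pattern else ()
        for m in matches:
            # Plain text before the marker is encoded byte by byte.
            ids.extend(b + NUM_RESERVED for b in text[pos:m.start()].encode("utf-8"))
            ids.append(self.added[m.group()])  # the marker is a single id
            pos = m.end()
        ids.extend(b + NUM_RESERVED for b in text[pos:].encode("utf-8"))
        ids.append(EOS_ID)
        return ids


def check_embedding_rows(checkpoint_rows, tokenizer):
    """Raise if the checkpoint's embedding table does not match the vocabulary size."""
    if checkpoint_rows != len(tokenizer):
        raise ValueError(
            f"embedding shape mismatch: checkpoint has {checkpoint_rows} rows, "
            f"tokenizer has {len(tokenizer)} ids"
        )


# Hypothetical marker set; the real checkpoint's font/color vocabulary is not listed here.
markers = [f"<cn-font-{i}>" for i in range(2)] + [f"<color-{i}>" for i in range(2)]
extended = ToyByteTokenizer(markers)
ckpt_rows = len(extended)  # pretend the checkpoint was trained with the markers

check_embedding_rows(ckpt_rows, extended)  # passes
print(extended.encode("<color-1>Hi"))  # [262, 75, 108, 1]

try:
    check_embedding_rows(ckpt_rows, ToyByteTokenizer())  # unextended vocabulary
except ValueError as err:
    print(err)  # embedding shape mismatch: checkpoint has 263 rows, tokenizer has 259 ids
```

Extending the vocabulary before loading, rather than resizing the embedding after, is what lets the saved table be copied in as-is.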

A typical forward pass takes a list of user-supplied prompt strings (possibly containing <cn-font-N> / <color-N> markers) and returns features of shape [B, byte5_max_length, diffusion_c_input_dim], ready to be concatenated onto cap_feats along the sequence dimension.
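Concatenating along the sequence dimension requires both feature tensors to agree on batch size and channel width, and the result's length is the sum of both lengths. A minimal stdlib sketch, with nested lists standing in for [B, L, C] tensors: cat_seq is a hypothetical helper that mirrors torch.cat([cap_feats, byte5_feats], dim=1), and placing the byte5 features after the caption features is an assumption.

```python
def shape3(x):
    """Return (B, L, C) for a nested-list tensor, assuming non-empty, uniform inner lengths."""
    return len(x), len(x[0]), len(x[0][0])


def cat_seq(cap_feats, byte5_feats):
    """Concatenate [B, L1, C] and [B, L2, C] into [B, L1 + L2, C] (sequence dim = 1)."""
    b1, _, c1 = shape3(cap_feats)
    b2, _, c2 = shape3(byte5_feats)
    if b1 != b2 or c1 != c2:
        raise ValueError(f"batch/channel mismatch: {(b1, c1)} vs {(b2, c2)}")
    # Per batch item, append the byte5 rows after the caption rows.
    return [cap + glyph for cap, glyph in zip(cap_feats, byte5_feats)]


B, C = 2, 4
cap_feats = [[[0.1] * C for _ in range(5)] for _ in range(B)]    # [2, 5, 4]
byte5_feats = [[[0.0] * C for _ in range(3)] for _ in range(B)]  # [2, 3, 4]
print(shape3(cat_seq(cap_feats, byte5_feats)))  # (2, 8, 4)
```

This is why the mapper has to project the T5 hidden size onto the diffusion model's channel width: only the sequence length is allowed to differ between the two inputs.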

logger module-attribute

logger = getLogger(__name__)

MingByT5Encoder

Bases: Module

Bundles the byte5 tokenizer, the T5 encoder, and the T5EncoderBlockByT5Mapper.

Build it with MingByT5Encoder.from_checkpoint(<model>/byte5) when the checkpoint ships byte5 weights; otherwise, callers can skip it and the pipeline falls back to conditioning without byte5.
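The optional-byte5 path can be sketched as a small guard around the real constructor. This is a hypothetical helper: maybe_load_byte5 and the assumption that the weights sit at <model>/byte5/byte5_model.pt are illustrative, not taken from the module. The loader is passed in as a callable so the sketch runs without vllm_omni installed.

```python
import tempfile
from pathlib import Path
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def maybe_load_byte5(model_dir: Path, loader: Callable[[Path], T]) -> Optional[T]:
    """Return loader(model_dir / "byte5") if byte5 weights appear to ship, else None.

    Assumption (hypothetical layout): the weights live at <model>/byte5/byte5_model.pt.
    """
    byte5_dir = Path(model_dir) / "byte5"
    if not (byte5_dir / "byte5_model.pt").is_file():
        # No byte5 weights: the caller falls back to conditioning without byte5.
        return None
    return loader(byte5_dir)


with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp)
    print(maybe_load_byte5(model_dir, loader=lambda d: d.name))  # None
    (model_dir / "byte5").mkdir()
    (model_dir / "byte5" / "byte5_model.pt").touch()
    print(maybe_load_byte5(model_dir, loader=lambda d: d.name))  # byte5
```

In a pipeline, the loader would wrap the documented classmethod, for example lambda d: MingByT5Encoder.from_checkpoint(d, device=device, dtype=dtype), and a None result means the byte5 branch is skipped.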

mapper instance-attribute

mapper = mapper

max_length instance-attribute

max_length = max_length

text_encoder instance-attribute

text_encoder = text_encoder

tokenizer instance-attribute

tokenizer = tokenizer

forward

forward(texts: list[str]) -> Tensor

Tokenize → T5 encode → mapper; masks out padded positions.

Returns a tensor of shape [B, max_length, sdxl_channels]. Padded positions are zeroed so that the downstream torch.cat with cap_feats does not inject values computed from padding tokens.
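The padding and masking step can be sketched without a model in the loop. In this stdlib-only sketch, a token id list is truncated or right-padded to max_length to build an attention mask, and feature rows at padded positions are replaced with zeros. pad_to_max_length, zero_padded, and PAD_ID = 0 are assumptions for illustration; in the module itself this operates on the mapper's output tensor for the whole batch.

```python
PAD_ID = 0  # assumed padding id


def pad_to_max_length(ids, max_length):
    """Truncate or right-pad a token id list; return (ids, mask) with mask 1 on real tokens."""
    ids = ids[:max_length]
    n_pad = max_length - len(ids)
    return ids + [PAD_ID] * n_pad, [1] * len(ids) + [0] * n_pad


def zero_padded(features, mask):
    """features: [L][C] rows for one sequence; rows where mask == 0 become all zeros."""
    return [row if keep else [0.0] * len(row) for row, keep in zip(features, mask)]


ids, mask = pad_to_max_length([75, 108, 1], max_length=5)
print(ids, mask)  # [75, 108, 1, 0, 0] [1, 1, 1, 0, 0]

# Stand-in for mapper output: every position, padded or not, has nonzero values.
features = [[0.5, -0.5] for _ in range(5)]
print(zero_padded(features, mask)[-1])  # [0.0, 0.0]
```

For a batch, the same function is applied per sequence with that sequence's mask, giving the [B, max_length, C] layout described above with zeros wherever the mask is 0.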

from_checkpoint classmethod

from_checkpoint(
    byte5_dir: Path, *, device: device, dtype: dtype
) -> MingByT5Encoder