Skip to content

vllm_omni.diffusion.models.ernie_image

ErnieImage diffusion model for vLLM-Omni.

This module implements ERNIE-Image text-to-image generation with:

- `ErnieImageTransformer2DModel`: custom DiT transformer
- `ErnieImagePipeline`: full generation pipeline

Modules:

| Name | Description |
| --- | --- |
| `ernie_image_transformer` | |
| `pipeline_ernie_image` | |

ErnieImagePipeline

Bases: Module, CFGParallelMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin

current_timestep property

current_timestep

default_sample_size instance-attribute

default_sample_size = 128

do_classifier_free_guidance property

do_classifier_free_guidance

guidance_scale property

guidance_scale

image_processor instance-attribute

image_processor = VaeImageProcessor(
    vae_scale_factor=vae_scale_factor
)

interrupt property

interrupt

is_distilled instance-attribute

is_distilled = is_distilled

num_timesteps property

num_timesteps

od_config instance-attribute

od_config = od_config

pe_model instance-attribute

pe_model = to(_execution_device)

pe_tokenizer instance-attribute

pe_tokenizer = from_pretrained(
    pe_base_path,
    subfolder="pe_tokenizer",
    local_files_only=True,
    trust_remote_code=True,
    use_fast=False,
)

scheduler instance-attribute

scheduler = from_pretrained(
    model,
    subfolder="scheduler",
    local_files_only=local_files_only,
)

support_image_input class-attribute instance-attribute

support_image_input = False

text_encoder instance-attribute

text_encoder = to(_execution_device)

tokenizer instance-attribute

tokenizer = from_pretrained(
    model,
    subfolder="tokenizer",
    local_files_only=local_files_only,
)

tokenizer_max_length instance-attribute

tokenizer_max_length = 512

transformer instance-attribute

transformer = ErnieImageTransformer2DModel(
    quant_config=quantization_config, **transformer_kwargs
)

use_pe instance-attribute

use_pe = True

vae instance-attribute

vae = to(_execution_device)

vae_scale_factor instance-attribute

vae_scale_factor = (
    2 ** len(block_out_channels)
    if getattr(self, "vae", None)
    else 16
)

weights_sources instance-attribute

weights_sources = [
    ComponentSource(
        model_or_path=model,
        subfolder="transformer",
        revision=None,
        prefix="transformer.",
        fall_back_to_pt=True,
    )
]

check_inputs

check_inputs(
    prompt,
    height,
    width,
    prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    guidance_scale=None,
)

encode_prompt

encode_prompt(
    prompt: str | list[str],
    device: device,
    num_images_per_prompt: int = 1,
    width: int = 1024,
    height: int = 1024,
    apply_pe: bool = True,
) -> list[Tensor]

forward

forward(
    req: OmniDiffusionRequest,
    prompt: str | list[str] | None = None,
    negative_prompt: str | list[str] | None = "",
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: int = 50,
    guidance_scale: float = 4.0,
    num_images_per_prompt: int = 1,
    generator: Generator | None = None,
    latents: Tensor | None = None,
    prompt_embeds: list[FloatTensor] | None = None,
    negative_prompt_embeds: list[FloatTensor] | None = None,
    output_type: str = "pil",
    return_dict: bool = True,
    callback_on_step_end: Callable[[int, int, dict], None]
    | None = None,
    callback_on_step_end_tensor_inputs: list[str] = [
        "latents"
    ],
) -> DiffusionOutput

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

ErnieImageTransformer2DModel

Bases: Module

adaLN_modulation instance-attribute

adaLN_modulation = Sequential(
    SiLU(), Linear(hidden_size, 6 * hidden_size)
)

config instance-attribute

config = SimpleNamespace(
    patch_size=patch_size,
    in_channels=in_channels,
    out_channels=out_channels,
    num_layers=num_layers,
    num_attention_heads=num_attention_heads,
    ffn_hidden_size=ffn_hidden_size,
    hidden_size=hidden_size,
    text_in_dim=text_in_dim,
    rope_theta=rope_theta,
    rope_axes_dim=rope_axes_dim,
    eps=eps,
    qk_layernorm=qk_layernorm,
)

dtype property

dtype: dtype

final_linear instance-attribute

final_linear = Linear(
    hidden_size, patch_size * patch_size * out_channels
)

final_norm instance-attribute

final_norm = ErnieImageAdaLNContinuous(hidden_size, eps)

gradient_checkpointing instance-attribute

gradient_checkpointing = False

head_dim instance-attribute

head_dim = hidden_size // num_attention_heads

hidden_size instance-attribute

hidden_size = hidden_size

in_channels instance-attribute

in_channels = in_channels

layers instance-attribute

layers = ModuleList(
    [
        (
            ErnieImageSharedAdaLNBlock(
                parallel_config=parallel_config,
                hidden_size=hidden_size,
                num_heads=num_attention_heads,
                ffn_hidden_size=ffn_hidden_size,
                eps=eps,
                qk_layernorm=qk_layernorm,
                quant_config=quant_config,
            )
        )
        for _ in (range(num_layers))
    ]
)

num_attention_heads instance-attribute

num_attention_heads = num_attention_heads

num_layers instance-attribute

num_layers = num_layers

out_channels instance-attribute

out_channels = out_channels or in_channels

parallel_config instance-attribute

parallel_config = parallel_config

patch_size instance-attribute

patch_size = patch_size

pos_embed instance-attribute

pos_embed = ErnieImageEmbedND3(
    dim=head_dim, theta=rope_theta, axes_dim=rope_axes_dim
)

text_in_dim instance-attribute

text_in_dim = text_in_dim

text_proj instance-attribute

text_proj = (
    Linear(text_in_dim, hidden_size, bias=False)
    if text_in_dim != hidden_size
    else None
)

time_embedding instance-attribute

time_embedding = TimestepEmbedding(hidden_size, hidden_size)

time_proj instance-attribute

time_proj = Timesteps(
    hidden_size,
    flip_sin_to_cos=False,
    downscale_freq_shift=0,
)

unified_prepare instance-attribute

unified_prepare = UnifiedPrepare(
    x_embedder, text_proj, pos_embed
)

x_embedder instance-attribute

x_embedder = ErnieImagePatchEmbedDynamic(
    in_channels, hidden_size, patch_size
)

forward

forward(
    hidden_states: Tensor,
    timestep: Tensor,
    text_bth: Tensor,
    text_lens: Tensor,
    return_dict: bool = True,
) -> Tensor | Transformer2DModelOutput

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]