
Text-To-Image

Source: https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/text_to_image

Generate images from text prompts using vLLM-Omni's diffusion pipeline entrypoints.

  • text_to_image.py: command-line script for single image generation with advanced options.
  • gradio_demo.py: lightweight Gradio UI for interactive prompt/seed/CFG exploration.


Overview

This folder provides several entrypoints for experimenting with text-to-image diffusion models using vLLM-Omni. NextStep-1.1 has a different architecture, so it accepts its own set of arguments and is loaded with a dedicated pipeline class (NextStep11Pipeline).

Supported Models

Model                                    Image Shape   Peak VRAM (GiB)*   Model Weights (GiB)
Qwen/Qwen-Image                          1024 x 1024   60.0               53.7
Qwen/Qwen-Image-2512                     1024 x 1024   60.0               53.7
Tongyi-MAI/Z-Image-Turbo                 1024 x 1024   24.8               19.2
stepfun-ai/NextStep-1.1                  512 x 512     71.8               28.1
meituan-longcat/LongCat-Image            1024 x 1024   71.2               27.3
AIDC-AI/Ovis-Image-7B                    1024 x 1024   71.8               17.1
OmniGen2/OmniGen2                        1024 x 1024   20.1               14.7
stabilityai/stable-diffusion-3.5-medium  1024 x 1024   20.1               15.6
black-forest-labs/FLUX.1-dev             1024 x 1024   33.9               31.4
black-forest-labs/FLUX.1-schnell         1024 x 1024   33.9               31.4
black-forest-labs/FLUX.2-klein-4B        1024 x 1024   72.7               14.9
black-forest-labs/FLUX.2-klein-9B        1024 x 1024   37.1               32.3
black-forest-labs/FLUX.2-dev             1024 x 1024   65.7               >80 (CPU offload required)
HunyuanImage-3.0                         1024 x 1024   80.0 (TP≥3)        160

Info

* Peak VRAM: based on basic single-card usage with batch size = 1 and no acceleration or optimization features enabled. FLUX.2-dev requires --enable-cpu-offload on a single 80 GiB GPU.

Default model: Qwen/Qwen-Image

Quick Start

Python API

Single-prompt generation:

from vllm_omni.entrypoints.omni import Omni

if __name__ == "__main__":
    omni = Omni(model="Qwen/Qwen-Image")
    prompt = "a cup of coffee on the table"
    outputs = omni.generate(prompt)
    images = outputs[0].request_output.images
    images[0].save("coffee.png")

Local CLI Usage

python text_to_image.py \
  --model Qwen/Qwen-Image \
  --prompt "a cup of coffee on the table" \
  --output coffee.png

Key Arguments

Common arguments:

Argument                 Type   Default                         Description
--prompt                 str    "a cup of coffee on the table"  Text description for image generation
--seed                   int    142                             Integer seed for deterministic sampling
--negative-prompt        str    None                            Negative prompt for classifier-free conditional guidance
--cfg-scale              float  4.0                             True CFG scale (model-specific guidance strength; used by Qwen-Image)
--guidance-scale         float  4.0                             Classifier-free guidance scale
--num-images-per-prompt  int    1                               Number of images per prompt (saved as output, output_1, ...)
--num-inference-steps    int    50                              Diffusion sampling steps (more steps = higher quality, slower)
--height                 int    1024                            Output image height in pixels
--width                  int    1024                            Output image width in pixels
--output                 str    "qwen_image_output.png"         Path to save the generated image
--vae-use-slicing        flag   off                             Enable VAE slicing for memory optimization
--vae-use-tiling         flag   off                             Enable VAE tiling for memory optimization
--cfg-parallel-size      int    1                               Set to 2 to enable CFG Parallel
--ulysses-degree         int    1                               Ulysses sequence parallel degree for multi-GPU inference
--ring-degree            int    1                               Ring sequence parallel degree for hybrid Ulysses + Ring inference
--ulysses-mode           str    "strict"                        Ulysses SP mode: "strict" or "advanced_uaa"
--enable-cpu-offload     flag   off                             Enable CPU offloading for diffusion models
--lora-path              str    None                            Path to PEFT LoRA adapter folder
--lora-scale             float  1.0                             Scale factor for LoRA weights
--use-system-prompt      str    None                            System prompt preset: en_unified, en_vanilla, en_recaption, en_think_recaption, dynamic, None, or custom (text supplied via --system-prompt). Recommended: en_unified. HunyuanImage-3.0 only.
--system-prompt          str    None                            Custom system prompt text. Only used when --use-system-prompt is set to custom. HunyuanImage-3.0 only.
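The --num-images-per-prompt row says extra images are saved as output, output_1, and so on. A minimal sketch of that naming rule, assuming the index is inserted before the file extension; output_paths is a hypothetical helper, not part of text_to_image.py, and the script's actual naming logic may differ.

```python
from pathlib import Path


def output_paths(output: str, num_images: int) -> list[Path]:
    """Hypothetical helper: expand --output into one path per generated image.

    Assumes the naming described for --num-images-per-prompt
    (output, output_1, output_2, ...); the real script may differ.
    """
    base = Path(output)
    paths = [base]
    for i in range(1, num_images):
        # Insert the index before the extension: coffee.png -> coffee_1.png
        paths.append(base.with_name(f"{base.stem}_{i}{base.suffix}"))
    return paths


print([p.as_posix() for p in output_paths("outputs/coffee.png", 3)])
```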

NextStep-1.1 specific arguments:

Argument            Type   Default     Description
--guidance-scale-2  float  1.0         Secondary guidance scale (e.g. image-level CFG)
--timesteps-shift   float  1.0         Timesteps shift parameter for sampling
--cfg-schedule      str    "constant"  CFG schedule type: "constant" or "linear"
--use-norm          flag   off         Apply layer normalization to sampled tokens
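--timesteps-shift defaults to 1.0, which is consistent with a multiplicative shift where 1.0 leaves the schedule unchanged. The sketch below uses the shift common in SD3-style flow-matching samplers, t' = s*t / (1 + (s - 1)*t), with t as the noise level in [0, 1]. Whether NextStep-1.1 applies exactly this formula is an assumption; shift_timestep is a hypothetical name.

```python
def shift_timestep(t: float, shift: float) -> float:
    """SD3-style flow-matching timestep shift (assumed, not confirmed for NextStep-1.1).

    t is the noise level in [0, 1], where 1.0 is pure noise. shift > 1 moves
    intermediate timesteps toward the noisy end; shift == 1.0 is the identity.
    """
    return shift * t / (1.0 + (shift - 1.0) * t)


steps = 4
timesteps = [1.0 - i / steps for i in range(steps)]  # 1.0, 0.75, 0.5, 0.25
print([round(shift_timestep(t, 1.0), 3) for t in timesteps])  # default shift: unchanged
print([round(shift_timestep(t, 3.0), 3) for t in timesteps])  # larger shift: more time at high noise
```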

If you encounter OOM errors, try using --vae-use-slicing and --vae-use-tiling to reduce memory usage.

Qwen-Image currently publishes recommended resolution presets (width x height): 1328x1328, 1664x928, 928x1664, 1472x1140, 1140x1472, 1584x1056, and 1056x1584. For the most reliable results, set --width and --height to one of these.
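If you start from an arbitrary target size, you can snap it to the nearest preset. A minimal sketch, reusing the ASPECT_RATIOS table from gradio_demo.py and comparing aspect ratios on a log scale so that wide and tall mismatches are weighted equally; closest_preset is a hypothetical helper, not part of either script.

```python
import math

# Qwen-Image presets as (width, height), copied from ASPECT_RATIOS in gradio_demo.py.
ASPECT_RATIOS: dict[str, tuple[int, int]] = {
    "1:1": (1328, 1328),
    "16:9": (1664, 928),
    "9:16": (928, 1664),
    "4:3": (1472, 1140),
    "3:4": (1140, 1472),
    "3:2": (1584, 1056),
    "2:3": (1056, 1584),
}


def closest_preset(width: int, height: int) -> tuple[str, int, int]:
    """Hypothetical helper: pick the preset whose aspect ratio is closest to width/height."""
    target = math.log(width / height)
    label, (w, h) = min(
        ASPECT_RATIOS.items(),
        key=lambda item: abs(math.log(item[1][0] / item[1][1]) - target),
    )
    return label, w, h


label, w, h = closest_preset(1920, 1080)
print(f"--width {w} --height {h}  # {label}")
```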

More CLI Examples

Tongyi Models

python text_to_image.py \
  --model Tongyi-MAI/Z-Image-Turbo \
  --prompt "a cup of coffee on the table" \
  --seed 42 \
  --guidance-scale 0.0 \
  --num-images-per-prompt 1 \
  --num-inference-steps 9 \
  --height 1024 \
  --width 1024 \
  --output outputs/coffee.png

Tongyi-MAI/Z-Image-Turbo is a distilled version of Z-Image. Distilled diffusion models usually need fewer inference steps (typically 4 to 9), and classifier-free guidance (CFG) is usually not applied, which is why the example above sets --guidance-scale 0.0. Other distilled models in the supported list are black-forest-labs/FLUX.2-klein-4B and black-forest-labs/FLUX.2-klein-9B.
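Fewer steps and no CFG compound. A rough count of transformer (DiT) forward passes per image makes the difference concrete, assuming one call per step for the conditional branch plus one more per step for the unconditional or negative branch when CFG is on. dit_forward_passes is a hypothetical helper; real pipelines may batch both branches into one call, which changes latency but not total compute.

```python
def dit_forward_passes(num_inference_steps: int, uses_cfg: bool) -> int:
    """Rough count of DiT forward passes for one image.

    Assumes one conditional pass per step, plus one unconditional/negative
    pass per step when classifier-free guidance is applied.
    """
    return num_inference_steps * (2 if uses_cfg else 1)


print(dit_forward_passes(50, uses_cfg=True))   # typical non-distilled setting
print(dit_forward_passes(9, uses_cfg=False))   # the Z-Image-Turbo command: 9 steps, guidance 0.0
```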

Example with Ulysses sequence parallelism in advanced_uaa mode (requires 2 GPUs):

python text_to_image.py \
  --model Tongyi-MAI/Z-Image-Turbo \
  --prompt "a cup of coffee on the table" \
  --ulysses-degree 2 \
  --ulysses-mode advanced_uaa \
  --height 1024 \
  --width 1024 \
  --output outputs/coffee_hybrid.png
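The default strict Ulysses mode requires divisibility. A minimal sketch of what such a check looks like for the image sequence, assuming 8x VAE downsampling followed by 2x2 patchification (common in FLUX- and Qwen-Image-style DiTs, not confirmed for every model here). It ignores text tokens and the requirement that attention heads also split evenly, and it does not model what advanced_uaa does differently. Both helper names are hypothetical.

```python
def image_tokens(height: int, width: int, vae_factor: int = 8, patch_size: int = 2) -> int:
    """Image sequence length, assuming an 8x VAE and 2x2 patches (an assumption)."""
    step = vae_factor * patch_size
    if height % step or width % step:
        raise ValueError(f"height and width must be multiples of {step}")
    return (height // step) * (width // step)


def check_strict_ulysses(seq_len: int, ulysses_degree: int) -> int:
    """Per-rank sequence length under an even split; raises if it does not divide."""
    if seq_len % ulysses_degree:
        raise ValueError(f"sequence length {seq_len} is not divisible by ulysses degree {ulysses_degree}")
    return seq_len // ulysses_degree


print(check_strict_ulysses(image_tokens(1024, 1024), 2))  # 4096 tokens split over 2 ranks
try:
    check_strict_ulysses(image_tokens(1584, 1056), 4)  # 6534 tokens do not split over 4 ranks
except ValueError as err:
    print(err)
```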

NextStep Models

NextStep-1.1 supports extra arguments for dual-level CFG control:

python text_to_image.py \
  --model stepfun-ai/NextStep-1.1 \
  --prompt "A baby panda wearing an Iron Man mask, holding a board with 'NextStep-1' written on it" \
  --height 512 \
  --width 512 \
  --num-inference-steps 28 \
  --guidance-scale 7.5 \
  --guidance-scale-2 1.0 \
  --cfg-schedule constant \
  --output nextstep_output.png \
  --seed 42

FLUX.2-dev Models

To run FLUX.2-dev on a single GPU, --enable-cpu-offload is required because the model weights exceed 80 GiB:

python examples/offline_inference/text_to_image/text_to_image.py \
  --model black-forest-labs/FLUX.2-dev \
  --prompt "a lovely bunny holding a sign that says 'vllm-omni'" \
  --seed 42 \
  --tensor-parallel-size 1 \
  --num-images-per-prompt 1 \
  --num-inference-steps 50 \
  --guidance-scale 4.0 \
  --height 1024 \
  --width 1024 \
  --enable-cpu-offload \
  --output flux2-dev.png
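A back-of-the-envelope sketch of why CPU offload makes a model larger than GPU memory fit, assuming module-level offload where each component (text encoder, transformer, VAE) is moved to the GPU only while it runs, similar in spirit to diffusers' enable_model_cpu_offload. The module sizes are made up for illustration and are not measured FLUX.2-dev numbers; activations and transfer overhead are not modeled.

```python
def peak_resident_gib(module_sizes_gib: dict[str, float], cpu_offload: bool) -> float:
    """Peak GPU weight memory in GiB, ignoring activations.

    Without offload every module stays on the GPU, so the peak is the sum.
    With module-level offload only the running module is resident, so the
    peak is the largest single module.
    """
    sizes = module_sizes_gib.values()
    return max(sizes) if cpu_offload else sum(sizes)


# Hypothetical sizes for illustration only.
modules = {"text_encoder": 45.0, "transformer": 60.0, "vae": 0.5}
print(peak_resident_gib(modules, cpu_offload=False))  # exceeds an 80 GiB card
print(peak_resident_gib(modules, cpu_offload=True))   # fits, at the cost of host-device copies
```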

Batch Requests (Multiple Prompts)

You can pass multiple prompts in a single generate call.

from vllm_omni.entrypoints.omni import Omni

if __name__ == "__main__":
    omni = Omni(model="Qwen/Qwen-Image")
    prompts = [
        "a cup of coffee on a table",
        "a toy dinosaur on a sandy beach",
        "a fox waking up in bed and yawning",
    ]
    outputs = omni.generate(prompts)
    for i, output in enumerate(outputs):
        output.request_output.images[0].save(f"{i}.jpg")

Info

Not all models support batch inference, and batched requests usually do not bring a significant performance improvement. This feature exists mainly for interface compatibility with vLLM and to allow for future improvements.

Info

For diffusion pipelines, the stage config field stage_args.[].runtime.max_batch_size defaults to 1, so the input list is sliced into single-item requests before it is fed into the diffusion pipeline. For models that support batched inputs internally, you can raise this value to let the pipeline accept several prompts per batch.
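The slicing described above can be pictured as splitting the prompt list into consecutive chunks of at most max_batch_size items. A minimal sketch; split_into_requests is a hypothetical name, and vLLM-Omni's internal scheduling may group requests differently.

```python
from typing import TypeVar

T = TypeVar("T")


def split_into_requests(prompts: list[T], max_batch_size: int = 1) -> list[list[T]]:
    """Slice prompts into consecutive batches of at most max_batch_size items.

    With the default max_batch_size=1, every prompt becomes its own request.
    """
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be >= 1")
    return [prompts[i : i + max_batch_size] for i in range(0, len(prompts), max_batch_size)]


prompts = [
    "a cup of coffee on a table",
    "a toy dinosaur on a sandy beach",
    "a fox waking up in bed and yawning",
]
print(split_into_requests(prompts))     # three single-prompt requests
print(split_into_requests(prompts, 2))  # one batch of two, then one of one
```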

Negative Prompts

vLLM-Omni supports dictionary prompts for models that accept negative prompts:

from vllm_omni.entrypoints.omni import Omni

if __name__ == "__main__":
    omni = Omni(model="Qwen/Qwen-Image")
    outputs = omni.generate([
        {
            "prompt": "a cup of coffee on a table",
            "negative_prompt": "low resolution"
        },
        {
            "prompt": "a toy dinosaur on a sandy beach",
            "negative_prompt": "cinematic, realistic"
        }
    ])
    for i, output in enumerate(outputs):
        output.request_output.images[0].save(f"{i}.jpg")

You can also pass a negative prompt via the CLI argument --negative-prompt:

python examples/offline_inference/text_to_image/text_to_image.py \
  --model Qwen/Qwen-Image \
  --prompt "a cup of coffee on a table" \
  --negative-prompt "low resolution, blurry" \
  --output coffee.png

Advanced Features

CFG Parallel

Set --cfg-parallel-size 2 to enable CFG Parallel, which computes the conditional and unconditional (negative) guidance branches on separate GPUs for faster inference on multi-GPU setups. See more examples in the cfg_parallel user guide.
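CFG Parallel works because the two guidance branches of a denoising step do not depend on each other and only meet in the final combination. A toy sketch of one step, using the textbook CFG formula uncond + scale * (cond - uncond); the unconditional branch is the one a negative prompt conditions. denoiser is a hypothetical stand-in for a DiT forward pass, threads stand in for the two GPUs, and individual pipelines (for example Qwen-Image's true CFG) may add further normalization.

```python
from concurrent.futures import ThreadPoolExecutor


def denoiser(latent: list[float], conditioning: str) -> list[float]:
    """Hypothetical stand-in for one DiT forward pass returning a fake noise prediction."""
    bias = 0.1 if conditioning == "positive" else -0.1
    return [x + bias for x in latent]


def cfg_step(latent: list[float], cfg_scale: float) -> list[float]:
    # The two branches are independent, so they can run concurrently.
    # CFG Parallel places them on two GPUs; threads only illustrate the idea here.
    with ThreadPoolExecutor(max_workers=2) as pool:
        cond_future = pool.submit(denoiser, latent, "positive")
        uncond_future = pool.submit(denoiser, latent, "negative")
        cond, uncond = cond_future.result(), uncond_future.result()
    # Textbook CFG combination: uncond + scale * (cond - uncond)
    return [u + cfg_scale * (c - u) for c, u in zip(cond, uncond)]


print(cfg_step([0.0, 1.0], cfg_scale=4.0))
```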

LoRA

This example supports PEFT-compatible LoRA (Low-Rank Adaptation) adapters for diffusion models. Pass --lora-path to load a LoRA adapter and, optionally, --lora-scale to set its strength (default 1.0). Omit --lora-path to use the base model only.

python text_to_image.py \
  --model Tongyi-MAI/Z-Image-Turbo \
  --prompt "A piece of cheesecake" \
  --lora-path /path/to/lora/ \
  --lora-scale 1.0 \
  --output output.png

LoRA adapters must be in PEFT format. A typical adapter directory structure:

lora_adapter/
├── adapter_config.json
└── adapter_model.safetensors
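Before launching a long model load, it can help to confirm that the adapter directory has this layout. A minimal pre-flight sketch; check_peft_adapter_dir is a hypothetical helper that only checks for the two files shown and that adapter_config.json is a JSON object. It does not validate the weights or whether the adapter matches the base model.

```python
import json
import tempfile
from pathlib import Path

REQUIRED_FILES = ("adapter_config.json", "adapter_model.safetensors")


def check_peft_adapter_dir(path: str) -> dict:
    """Hypothetical pre-flight check for a --lora-path directory."""
    root = Path(path)
    missing = [name for name in REQUIRED_FILES if not (root / name).is_file()]
    if missing:
        raise FileNotFoundError(f"{root} is missing: {', '.join(missing)}")
    config = json.loads((root / "adapter_config.json").read_text())
    if not isinstance(config, dict):
        raise ValueError("adapter_config.json must contain a JSON object")
    return config


# Usage with a throwaway directory; file contents are placeholders.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "adapter_config.json").write_text(json.dumps({"peft_type": "LORA", "r": 16}))
    (Path(tmp) / "adapter_model.safetensors").write_bytes(b"")
    print(check_peft_adapter_dir(tmp)["peft_type"])
```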

Web UI Demo

Launch the Gradio demo:

python gradio_demo.py --port 7862

Then open http://localhost:7862/ in your local browser to interact with the web UI.

Example materials

gradio_demo.py
import argparse
from functools import lru_cache

try:
    import gradio as gr
except ImportError:
    raise ImportError("gradio is required to run this demo. Install it with: pip install 'vllm-omni[demo]'") from None
import torch

from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform

ASPECT_RATIOS: dict[str, tuple[int, int]] = {
    "1:1": (1328, 1328),
    "16:9": (1664, 928),
    "9:16": (928, 1664),
    "4:3": (1472, 1140),
    "3:4": (1140, 1472),
    "3:2": (1584, 1056),
    "2:3": (1056, 1584),
}
ASPECT_RATIO_CHOICES = [f"{ratio} ({w}x{h})" for ratio, (w, h) in ASPECT_RATIOS.items()]


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Gradio demo for Qwen-Image offline inference.")
    parser.add_argument("--model", default="Qwen/Qwen-Image", help="Diffusion model name or local path.")
    parser.add_argument(
        "--height",
        type=int,
        default=1328,
        help="Default image height (must match one of the supported presets).",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1328,
        help="Default image width (must match one of the supported presets).",
    )
    parser.add_argument("--default-prompt", default="a cup of coffee on the table", help="Initial prompt shown in UI.")
    parser.add_argument("--default-seed", type=int, default=42, help="Initial seed shown in UI.")
    parser.add_argument("--default-cfg-scale", type=float, default=4.0, help="Initial CFG scale shown in UI.")
    parser.add_argument(
        "--num-inference-steps",
        type=int,
        default=50,
        help="Default number of denoising steps shown in the UI.",
    )
    parser.add_argument("--ip", default="127.0.0.1", help="Host/IP for Gradio `launch`.")
    parser.add_argument("--port", type=int, default=7862, help="Port for Gradio `launch`.")
    parser.add_argument("--share", action="store_true", help="Share the Gradio demo publicly.")
    args = parser.parse_args()
    args.aspect_ratio_label = next(
        (ratio for ratio, dims in ASPECT_RATIOS.items() if dims == (args.width, args.height)),
        None,
    )
    if args.aspect_ratio_label is None:
        supported = ", ".join(f"{ratio} ({w}x{h})" for ratio, (w, h) in ASPECT_RATIOS.items())
        parser.error(f"Unsupported resolution {args.width}x{args.height}. Please pick one of: {supported}.")
    return args


@lru_cache(maxsize=1)
def get_omni(model_name: str) -> Omni:
    # Enable VAE memory optimizations on NPU
    vae_use_slicing = current_omni_platform.is_npu()
    vae_use_tiling = current_omni_platform.is_npu()
    return Omni(
        model=model_name,
        vae_use_slicing=vae_use_slicing,
        vae_use_tiling=vae_use_tiling,
    )


def build_demo(args: argparse.Namespace) -> gr.Blocks:
    omni = get_omni(args.model)

    def run_inference(
        prompt: str,
        seed_value: float,
        cfg_scale_value: float,
        resolution_choice: str,
        num_steps_value: float,
        num_images_choice: str,  # gr.Dropdown passes the selected choice as a string
    ):
        if not prompt or not prompt.strip():
            raise gr.Error("Please enter a non-empty prompt.")
        ratio_label = resolution_choice.split(" ", 1)[0]
        if ratio_label not in ASPECT_RATIOS:
            raise gr.Error(f"Unsupported aspect ratio: {ratio_label}")
        width, height = ASPECT_RATIOS[ratio_label]
        try:
            seed = int(seed_value)
            num_steps = int(num_steps_value)
            num_images = int(num_images_choice)
        except (TypeError, ValueError) as exc:
            raise gr.Error("Seed, inference steps, and number of images must be valid integers.") from exc
        if num_steps <= 0:
            raise gr.Error("Inference steps must be a positive integer.")
        if num_images not in {1, 2, 3, 4}:
            raise gr.Error("Number of images must be 1, 2, 3, or 4.")
        generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(seed)
        outputs = omni.generate(
            prompt.strip(),
            OmniDiffusionSamplingParams(
                height=height,
                width=width,
                generator=generator,
                true_cfg_scale=float(cfg_scale_value),
                num_inference_steps=num_steps,
                num_outputs_per_prompt=num_images,
            ),
        )
        images_outputs = []
        for output in outputs:
            req_out = output.request_output
            if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
                raise ValueError("Invalid request_output structure or missing 'images' attribute")
            images = req_out.images
            if not images:
                raise ValueError("No images found in request_output")
            # Extend the list with individual images (not append the entire list)
            images_outputs.extend(images)
            if len(images_outputs) >= num_images:
                break
        # Return only the requested number of images
        return images_outputs[:num_images]

    with gr.Blocks(
        title="vLLM-Omni Web Serving Demo",
        css="""
        /* Left column button width */
        .left-column button {
            width: 100%;
        }
        /* Right preview area: fixed height, hide unnecessary buttons */
        .fixed-image {
            height: 660px;
            display: flex;
            flex-direction: column;
            justify-content: center;
            align-items: center;
        }
        .fixed-image .duplicate-button,
        .fixed-image .svelte-drgfj2 {
            display: none !important;
        }
        /* Gallery container: fill available space and center content */
        #image-gallery {
            width: 100%;
            height: 100%;
            display: flex;
            align-items: center;
            justify-content: center;
        }
        /* Gallery grid: center horizontally and vertically, set gap */
        #image-gallery .grid {
            display: flex;
            flex-wrap: wrap;
            justify-content: center;
            align-items: center;
            align-content: center;
            gap: 16px;
            width: 100%;
            height: 100%;
        }
        /* Gallery grid items: center content */
        #image-gallery .grid > div {
            display: flex;
            align-items: center;
            justify-content: center;
        }
        /* Gallery images: limit max height, maintain aspect ratio */
        .fixed-image img {
            max-height: 660px !important;
            width: auto !important;
            object-fit: contain;
        }
        """,
    ) as demo:
        gr.Markdown("# vLLM-Omni Web Serving Demo")
        gr.Markdown(f"**Model:** {args.model}")

        with gr.Row():
            with gr.Column(scale=1, elem_classes="left-column"):
                prompt_input = gr.Textbox(
                    label="Prompt",
                    value=args.default_prompt,
                    placeholder="Describe the image you want to generate...",
                    lines=5,
                )
                seed_input = gr.Number(label="Seed", value=args.default_seed, precision=0)
                cfg_input = gr.Number(label="CFG Scale", value=args.default_cfg_scale)
                steps_input = gr.Number(
                    label="Inference Steps",
                    value=args.num_inference_steps,
                    precision=0,
                    minimum=1,
                )
                aspect_dropdown = gr.Dropdown(
                    label="Aspect Ratio (W:H)",
                    choices=ASPECT_RATIO_CHOICES,
                    value=f"{args.aspect_ratio_label} ({ASPECT_RATIOS[args.aspect_ratio_label][0]}x{ASPECT_RATIOS[args.aspect_ratio_label][1]})",
                )
                num_images = gr.Dropdown(
                    label="Number of images",
                    choices=["1", "2", "3", "4"],
                    value="1",
                )
                generate_btn = gr.Button("Generate", variant="primary")
            with gr.Column(scale=2, elem_classes="fixed-image"):
                gallery = gr.Gallery(
                    label="Preview",
                    columns=2,
                    rows=2,
                    height=660,
                    allow_preview=True,
                    show_label=True,
                    elem_id="image-gallery",
                )

        generate_btn.click(
            fn=run_inference,
            inputs=[prompt_input, seed_input, cfg_input, aspect_dropdown, steps_input, num_images],
            outputs=gallery,
        )

    return demo


def main():
    args = parse_args()
    demo = build_demo(args)
    demo.launch(server_name=args.ip, server_port=args.port, share=args.share)


if __name__ == "__main__":
    main()
text_to_image.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import json
import time
from pathlib import Path
from typing import Any

import torch

from vllm_omni.diffusion.data import logger
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.lora.request import LoRARequest
from vllm_omni.lora.utils import stable_lora_int_id
from vllm_omni.platforms import current_omni_platform


def is_nextstep_model(model_name: str) -> bool:
    """Check if the model is a NextStep model by reading its config."""
    from vllm.transformers_utils.config import get_hf_file_to_dict

    try:
        cfg = get_hf_file_to_dict("config.json", model_name)
        if cfg and cfg.get("model_type") == "nextstep":
            return True
    except Exception:
        pass
    return False


def parse_profiler_config(value: str) -> dict[str, Any]:
    try:
        config = json.loads(value)
    except json.JSONDecodeError as e:
        raise argparse.ArgumentTypeError(f"--profiler-config must be valid JSON: {e}") from e
    if not isinstance(config, dict):
        raise argparse.ArgumentTypeError("--profiler-config must be a JSON object")
    return config


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Generate an image with supported diffusion models.")
    parser.add_argument(
        "--model",
        default="Qwen/Qwen-Image",
        help="Diffusion model name or local path. Supported models include: "
        "Qwen/Qwen-Image, Qwen/Qwen-Image-2512, Tongyi-MAI/Z-Image-Turbo, stepfun-ai/NextStep-1.1, "
        "black-forest-labs/FLUX.1-dev, black-forest-labs/FLUX.2-klein-9B, "
        "black-forest-labs/FLUX.2-dev, tencent/HunyuanImage-3.0-Instruct, "
        "meituan-longcat/LongCat-Image, OvisAI/Ovis-Image, "
        "stabilityai/stable-diffusion-3.5-medium, etc.",
    )
    parser.add_argument(
        "--stage-configs-path",
        type=str,
        default=None,
        help="Path to a YAML file containing stage configurations for Omni.",
    )
    parser.add_argument("--prompt", default="a cup of coffee on the table", help="Text prompt for image generation.")
    parser.add_argument(
        "--negative-prompt",
        default=None,
        help="Negative prompt for classifier-free conditional guidance.",
    )
    parser.add_argument("--seed", type=int, default=142, help="Random seed for deterministic results.")
    parser.add_argument(
        "--cfg-scale",
        type=float,
        default=4.0,
        help="True classifier-free guidance scale specific to Qwen-Image.",
    )
    parser.add_argument(
        "--guidance-scale",
        type=float,
        default=4.0,
        help="Classifier-free guidance scale. HunyuanImage3 recommends 4.0-5.0.",
    )
    parser.add_argument("--height", type=int, default=1024, help="Height of generated image.")
    parser.add_argument("--width", type=int, default=1024, help="Width of generated image.")
    parser.add_argument(
        "--output",
        type=str,
        default="qwen_image_output.png",
        help="Path to save the generated image (PNG).",
    )
    parser.add_argument(
        "--num-images-per-prompt",
        type=int,
        default=1,
        help="Number of images to generate for the given prompt.",
    )
    parser.add_argument(
        "--num-inference-steps",
        type=int,
        default=50,
        help="Number of denoising steps for the diffusion sampler.",
    )
    parser.add_argument(
        "--cache-backend",
        type=str,
        default=None,
        choices=["cache_dit", "tea_cache"],
        help=(
            "Cache backend to use for acceleration. "
            "Options: 'cache_dit' (DBCache + SCM + TaylorSeer), 'tea_cache' (Timestep Embedding Aware Cache). "
            "Default: None (no cache acceleration)."
        ),
    )
    parser.add_argument(
        "--enable-cache-dit-summary",
        action="store_true",
        help="Enable cache-dit summary logging after diffusion forward passes.",
    )
    parser.add_argument(
        "--ulysses-degree",
        type=int,
        default=1,
        help="Number of GPUs used for Ulysses sequence parallelism.",
    )
    parser.add_argument(
        "--ulysses-mode",
        type=str,
        default="strict",
        choices=["strict", "advanced_uaa"],
        help="Ulysses sequence-parallel mode: 'strict' (divisibility required) or 'advanced_uaa' (UAA).",
    )
    parser.add_argument(
        "--ring-degree",
        type=int,
        default=1,
        help="Number of GPUs used for ring sequence parallelism.",
    )
    parser.add_argument(
        "--cfg-parallel-size",
        type=int,
        default=1,
        choices=[1, 2],
        help="Number of GPUs used for classifier-free guidance (CFG) parallelism.",
    )
    parser.add_argument(
        "--enforce-eager",
        action="store_true",
        help="Disable torch.compile and force eager execution.",
    )
    parser.add_argument(
        "--enable-cpu-offload",
        action="store_true",
        help="Enable CPU offloading for diffusion models.",
    )
    parser.add_argument(
        "--enable-layerwise-offload",
        action="store_true",
        help="Enable layerwise (blockwise) offloading on DiT modules.",
    )
    parser.add_argument(
        "--use-hsdp",
        action="store_true",
        help="Enable HSDP (Hybrid Sharded Data Parallel) for diffusion models.",
    )
    parser.add_argument(
        "--hsdp-shard-size",
        type=int,
        default=1,
        help="Number of GPUs to shard weights across for HSDP.",
    )
    parser.add_argument(
        "--hsdp-replicate-size",
        type=int,
        default=1,
        help="Number of HSDP replica groups.",
    )
    parser.add_argument(
        "--quantization",
        type=str,
        default=None,
        choices=["fp8", "int8", "gguf"],
        help="Quantization method for the transformer. "
        "Options: 'fp8' (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs), 'int8' (Int8 W8A8), 'gguf' (GGUF quantized weights). "
        "Default: None (no quantization, uses BF16).",
    )
    parser.add_argument(
        "--gguf-model",
        type=str,
        default=None,
        help=("GGUF file path or HF reference for transformer weights. Required when --quantization gguf is set."),
    )
    parser.add_argument(
        "--ignored-layers",
        type=str,
        default=None,
        help="Comma-separated list of layer name patterns to skip quantization. "
        "Only used when --quantization is set. "
        "Available layers: to_qkv, to_out, add_kv_proj, to_add_out, img_mlp, txt_mlp, proj_out. "
        "Example: --ignored-layers 'add_kv_proj,to_add_out'",
    )
    parser.add_argument(
        "--vae-use-slicing",
        action="store_true",
        help="Enable VAE slicing for memory optimization.",
    )
    parser.add_argument(
        "--vae-use-tiling",
        action="store_true",
        help="Enable VAE tiling for memory optimization.",
    )
    parser.add_argument(
        "--tensor-parallel-size",
        type=int,
        default=1,
        help="Number of GPUs used for tensor parallelism (TP) inside the DiT.",
    )
    parser.add_argument(
        "--enable-expert-parallel",
        action="store_true",
        help="Enable expert parallelism for MoE layers.",
    )
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to LoRA adapter folder (PEFT format). Loaded at initialization and used for generation.",
    )
    parser.add_argument(
        "--lora-scale",
        type=float,
        default=1.0,
        help="Scale factor for LoRA weights (default: 1.0).",
    )
    parser.add_argument(
        "--vae-patch-parallel-size",
        type=int,
        default=1,
        help="Number of ranks used for VAE patch/tile parallelism (decode/encode).",
    )
    # NextStep-1.1 specific arguments
    parser.add_argument(
        "--guidance-scale-2",
        type=float,
        default=1.0,
        help="Secondary guidance scale (e.g. image-level CFG for NextStep-1.1).",
    )
    parser.add_argument(
        "--timesteps-shift",
        type=float,
        default=1.0,
        help="[NextStep-1.1 only] Timesteps shift parameter for sampling.",
    )
    parser.add_argument(
        "--cfg-schedule",
        type=str,
        default="constant",
        choices=["constant", "linear"],
        help="[NextStep-1.1 only] CFG schedule type.",
    )
    parser.add_argument(
        "--use-norm",
        action="store_true",
        help="[NextStep-1.1 only] Apply layer normalization to sampled tokens.",
    )
    parser.add_argument(
        "--enable-diffusion-pipeline-profiler",
        action="store_true",
        help="Enable diffusion pipeline profiler to display stage durations.",
    )
    parser.add_argument(
        "--profiler-config",
        type=parse_profiler_config,
        default=None,
        help='JSON profiler config for torch/cuda profiling, e.g. \'{"profiler":"torch","torch_profiler_dir":"./perf"}\'.',
    )
    parser.add_argument(
        "--log-stats",
        action="store_true",
        help="Enable logging of diffusion pipeline stats.",
    )
    parser.add_argument(
        "--init-timeout",
        type=int,
        default=600,
        help="Timeout for initializing a single stage in seconds (default: 600s)",
    )
    parser.add_argument(
        "--stage-init-timeout",
        type=int,
        default=600,
        help="Timeout for initializing a single stage in seconds (default: 600s)",
    )
    parser.add_argument(
        "--use-system-prompt",
        type=str,
        default=None,
        choices=["None", "dynamic", "en_vanilla", "en_recaption", "en_think_recaption", "en_unified", "custom"],
        help="System prompt preset for generation. Recommended: en_unified.",
    )
    parser.add_argument(
        "--system-prompt",
        type=str,
        default=None,
        help="Custom system prompt. Used when --use-system-prompt is custom.",
    )
    current_omni_platform.pre_register_and_update(parser)
    from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults

    nullify_stage_engine_defaults(parser)
    return parser.parse_args()


def main():
    args = parse_args()
    generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed)
    use_nextstep = is_nextstep_model(args.model)

    cache_config = None
    cache_backend = args.cache_backend

    if cache_backend == "cache_dit":
        # cache-dit configuration: Hybrid DBCache + SCM + TaylorSeer
        # All parameters marked with [cache-dit only] in DiffusionCacheConfig
        cache_config = {
            # DBCache parameters [cache-dit only]
            "Fn_compute_blocks": 1,  # Optimized for single-transformer models
            "Bn_compute_blocks": 0,  # Number of backward compute blocks
            "max_warmup_steps": 4,  # Maximum warmup steps (works for few-step models)
            "residual_diff_threshold": 0.24,  # Higher threshold for more aggressive caching
            "max_continuous_cached_steps": 3,  # Limit to prevent precision degradation
            # TaylorSeer parameters [cache-dit only]
            "enable_taylorseer": False,  # Disabled by default (not suitable for few-step models)
            "taylorseer_order": 1,  # TaylorSeer polynomial order
            # SCM (Step Computation Masking) parameters [cache-dit only]
            "scm_steps_mask_policy": None,  # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra"
            "scm_steps_policy": "dynamic",  # SCM steps policy: "dynamic" or "static"
        }
    elif cache_backend == "tea_cache":
        # TeaCache configuration
        # All parameters marked with [tea_cache only] in DiffusionCacheConfig
        cache_config = {
            # TeaCache parameters [tea_cache only]
            "rel_l1_thresh": 0.2,  # Threshold for accumulated relative L1 distance
            # Note: coefficients will use model-specific defaults based on model_type
            #       (e.g., QwenImagePipeline or FluxPipeline)
        }

    profiler_enabled = args.profiler_config is not None

    # Prepare LoRA kwargs for Omni initialization
    lora_args: dict[str, Any] = {}
    if args.lora_path:
        lora_args["lora_path"] = args.lora_path
        print(f"Using LoRA from: {args.lora_path}")

    # Build quantization kwargs: use quantization_config dict when
    # ignored_layers is specified so the list flows through OmniDiffusionConfig
    quant_kwargs: dict[str, Any] = {}
    ignored_layers = [s.strip() for s in args.ignored_layers.split(",") if s.strip()] if args.ignored_layers else None
    if args.quantization == "gguf":
        if not args.gguf_model:
            raise ValueError("--gguf-model is required when --quantization gguf is set.")
        quant_kwargs["quantization_config"] = {
            "method": "gguf",
            "gguf_model": args.gguf_model,
        }
    elif args.quantization and ignored_layers:
        quant_kwargs["quantization_config"] = {
            "method": args.quantization,
            "ignored_layers": ignored_layers,
        }
    elif args.quantization:
        quant_kwargs["quantization"] = args.quantization

    omni_kwargs = {
        "model": args.model,
        "enable_layerwise_offload": args.enable_layerwise_offload,
        "vae_use_slicing": args.vae_use_slicing,
        "vae_use_tiling": args.vae_use_tiling,
        "cache_backend": args.cache_backend,
        "cache_config": cache_config,
        "enable_cache_dit_summary": args.enable_cache_dit_summary,
        "ulysses_degree": args.ulysses_degree,
        "ring_degree": args.ring_degree,
        "ulysses_mode": args.ulysses_mode,
        "cfg_parallel_size": args.cfg_parallel_size,
        "tensor_parallel_size": args.tensor_parallel_size,
        "vae_patch_parallel_size": args.vae_patch_parallel_size,
        "enable_expert_parallel": args.enable_expert_parallel,
        "enforce_eager": args.enforce_eager,
        "enable_cpu_offload": args.enable_cpu_offload,
        "mode": "text-to-image",
        "log_stats": args.log_stats,
        "enable_diffusion_pipeline_profiler": args.enable_diffusion_pipeline_profiler,
        "profiler_config": args.profiler_config,
        "init_timeout": args.init_timeout,
        "stage_init_timeout": args.stage_init_timeout,
        **lora_args,
        **quant_kwargs,
    }
    if args.stage_configs_path:
        omni_kwargs["stage_configs_path"] = args.stage_configs_path
    if use_nextstep:
        # NextStep-1.1 requires explicit pipeline class
        omni_kwargs["model_class_name"] = "NextStep11Pipeline"
    omni = Omni(**omni_kwargs)

    if profiler_enabled:
        print("[Profiler] Starting profiling...")
        omni.start_profile()

    # Print the generation configuration
    print(f"\n{'=' * 60}")
    print("Generation Configuration:")
    print(f"  Model: {args.model}")
    print(f"  Inference steps: {args.num_inference_steps}")
    print(f"  Cache backend: {cache_backend if cache_backend else 'None (no acceleration)'}")
    print(f"  Quantization: {args.quantization if args.quantization else 'None (BF16)'}")
    if ignored_layers:
        print(f"  Ignored layers: {ignored_layers}")
    print(
        f"  Parallel configuration: tensor_parallel_size={args.tensor_parallel_size}, "
        f"ulysses_degree={args.ulysses_degree}, ulysses_mode={args.ulysses_mode}, "
        f"ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}, "
        f"vae_patch_parallel_size={args.vae_patch_parallel_size}, "
        f"enable_expert_parallel={args.enable_expert_parallel}."
    )
    print(f"  CPU offload: {args.enable_cpu_offload}; CPU Layerwise Offload: {args.enable_layerwise_offload}")
    print(f"  Image size: {args.width}x{args.height}")
    if args.lora_path:
        print(f"  LoRA: scale={args.lora_scale}")
    if args.stage_configs_path:
        print(f"  stage-configs-path: {args.stage_configs_path}")
    print(f"{'=' * 60}\n")

    # Build LoRA request when --lora-path is set
    lora_request = None
    if args.lora_path:
        lora_request_id = stable_lora_int_id(args.lora_path)
        lora_request = LoRARequest(
            lora_name=Path(args.lora_path).stem,
            lora_int_id=lora_request_id,
            lora_path=args.lora_path,
        )

    generation_start = time.perf_counter()
    extra_args = {
        "timesteps_shift": args.timesteps_shift,
        "cfg_schedule": args.cfg_schedule,
        "use_norm": args.use_norm,
        "use_system_prompt": args.use_system_prompt,
        "system_prompt": args.system_prompt,
    }
    if lora_request:
        extra_args["lora_request"] = lora_request
        extra_args["lora_scale"] = args.lora_scale

    outputs = omni.generate(
        {
            "prompt": args.prompt,
            "negative_prompt": args.negative_prompt,
        },
        OmniDiffusionSamplingParams(
            height=args.height,
            width=args.width,
            generator=generator,
            true_cfg_scale=args.cfg_scale,
            guidance_scale=args.guidance_scale,
            guidance_scale_2=args.guidance_scale_2,
            num_inference_steps=args.num_inference_steps,
            num_outputs_per_prompt=args.num_images_per_prompt,
            extra_args=extra_args,
        ),
    )

    generation_end = time.perf_counter()
    generation_time = generation_end - generation_start

    # Report wall-clock generation time
    print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)")

    if profiler_enabled:
        print("\n[Profiler] Stopping profiler and collecting results...")
        profile_results = omni.stop_profile()
        if profile_results and isinstance(profile_results, dict):
            traces = profile_results.get("traces", [])
            print("\n" + "=" * 60)
            print("PROFILING RESULTS:")
            for rank, trace in enumerate(traces):
                print(f"\nRank {rank}:")
                if trace:
                    print(f"  • Trace: {trace}")
            if not traces:
                print("  No traces collected.")
            print("=" * 60)
        else:
            print("[Profiler] No valid profiling data returned.")

    # omni.generate() returns list[OmniRequestOutput]
    if not outputs:
        raise ValueError("No output generated from omni.generate()")
    logger.info(f"Outputs: {outputs}")

    first_output = outputs[0]
    if not hasattr(first_output, "request_output") or not first_output.request_output:
        raise ValueError("No request_output found in OmniRequestOutput")

    req_out = first_output.request_output
    if not hasattr(req_out, "images"):
        raise ValueError("Invalid request_output structure or missing 'images'.")

    images = req_out.images
    if not images:
        raise ValueError("No images found in request_output")

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    suffix = output_path.suffix or ".png"
    stem = output_path.stem or "qwen_image_output"
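    if len(images) <= 1:
        # PIL infers the file format from the extension, so fall back to
        # the default suffix when --output has none.
        single_path = output_path if output_path.suffix else output_path.with_suffix(suffix)
        images[0].save(single_path)
        print(f"Saved generated image to {single_path}")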
    else:
        for idx, img in enumerate(images):
            save_path = output_path.parent / f"{stem}_{idx}{suffix}"
            img.save(save_path)
            print(f"Saved generated image to {save_path}")

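
# --- Illustrative sketch, not part of the original example ---------------
# A hypothetical helper (not called by main()) that restates the quantization
# branch of main() as a pure function, so each case can be checked in
# isolation without constructing Omni. It follows the same rules as main():
#   * "gguf" needs a GGUF file and always yields a quantization_config dict;
#   * any other method plus --ignored-layers yields a quantization_config dict
#     so the layer list can reach OmniDiffusionConfig;
#   * a bare method name is forwarded as `quantization`;
#   * no method yields no quantization kwargs at all.
def _sketch_build_quant_kwargs(quantization, ignored_layers_arg, gguf_model=None) -> dict[str, Any]:
    # Same comma-splitting rule as main(): trim entries and drop empty ones.
    ignored_layers = (
        [s.strip() for s in ignored_layers_arg.split(",") if s.strip()] if ignored_layers_arg else None
    )
    if quantization == "gguf":
        if not gguf_model:
            raise ValueError("--gguf-model is required when --quantization gguf is set.")
        return {"quantization_config": {"method": "gguf", "gguf_model": gguf_model}}
    if quantization and ignored_layers:
        return {"quantization_config": {"method": quantization, "ignored_layers": ignored_layers}}
    if quantization:
        return {"quantization": quantization}
    return {}


# Usage: _sketch_build_quant_kwargs("<method>", "layer_a, layer_b") returns
#   {"quantization_config": {"method": "<method>", "ignored_layers": ["layer_a", "layer_b"]}}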

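
# --- Illustrative sketch, not part of the original example ---------------
# A hypothetical helper (not called by main()) that lists the file names
# main() writes when the request returns more than one image: each image
# gets an "_<idx>" tag between the --output stem and its suffix, with ".png"
# and "qwen_image_output" as the same fallbacks main() uses. It only computes
# paths and never touches the filesystem.
def _sketch_indexed_output_paths(output, num_images: int) -> list[Path]:
    output_path = Path(output)
    suffix = output_path.suffix or ".png"
    stem = output_path.stem or "qwen_image_output"
    return [output_path.parent / f"{stem}_{idx}{suffix}" for idx in range(num_images)]


# Usage: _sketch_indexed_output_paths("out/cat.png", 2)
#   -> [Path("out/cat_0.png"), Path("out/cat_1.png")]
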
if __name__ == "__main__":
    main()