vllm_omni.diffusion.worker.diffusion_model_runner
Diffusion Model Runner for vLLM-Omni.
Handles model loading, compilation, caching, and execution of diffusion model forward passes. This follows the autoregressive (AR) runner pattern, in which the Runner handles all model-related operations.
DiffusionModelRunner
Bases: OmniConnectorModelRunnerMixin
Model runner that handles model loading and execution for diffusion models.
This class follows the AR pattern where the Runner handles all model-related operations including loading, compilation, offloading, caching, and execution. The Worker only handles infrastructure (device, distributed env).
clear_prompt_embed_cache
Evict all cached text-encoder outputs (e.g. between training epochs).
execute_model
execute_model(req: OmniDiffusionRequest) -> DiffusionOutput
Execute a forward pass for the given request.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| req | OmniDiffusionRequest | A diffusion request containing a list of prompts to process. | required |
Returns:
| Type | Description |
|---|---|
| DiffusionOutput | DiffusionOutput with generated results. |
Note
We use torch.no_grad() for HSDP because HSDP2's fully_shard requires access to tensor version counters in pre_forward hooks, which inference tensors do not track. For non-HSDP inference, we use torch.inference_mode() for better performance.
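The choice described in the note can be sketched as a small helper. This is an illustrative sketch rather than the runner's actual code: the `inference_context` name and the `use_hsdp` flag are assumptions made for the example; only `torch.no_grad()`, `torch.inference_mode()`, and `Tensor.is_inference()` are real PyTorch APIs.

```python
import torch


def inference_context(use_hsdp: bool):
    """Return a gradient-disabling context manager for a forward pass.

    Hypothetical helper: the name and the use_hsdp flag are illustrative.
    """
    if use_hsdp:
        # Tensors created under no_grad are ordinary tensors, so they still
        # track the version counters that fully_shard's hooks rely on.
        return torch.no_grad()
    # inference_mode creates inference tensors, which skip version-counter
    # and autograd bookkeeping and are therefore cheaper.
    return torch.inference_mode()


with inference_context(use_hsdp=True):
    hsdp_out = torch.ones(2) * 2

with inference_context(use_hsdp=False):
    fast_out = torch.ones(2) * 2

# Shows which of the two results is an inference tensor.
print(hsdp_out.is_inference(), fast_out.is_inference())
```

Neither result requires gradients; the difference is only whether the output is an inference tensor, which is what matters to hooks that inspect version counters.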
execute_stepwise
execute_stepwise(
    scheduler_output: DiffusionSchedulerOutput,
) -> BatchRunnerOutput
Execute one step for one scheduled request and return runner output.
get_prompt_embed_cache_stats
get_prompt_embed_cache_stats() -> dict | None
Return hit/miss statistics for the prompt-embedding cache, if enabled.
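To make the hit/miss bookkeeping concrete, here is a minimal sketch of a prompt-embedding cache with eviction and statistics. It is not the vLLM-Omni implementation: the `PromptEmbedCache` class, its method names, the LRU policy, and the keys of the returned dict are all assumptions chosen for illustration.

```python
from collections import OrderedDict


class PromptEmbedCache:
    """Toy LRU cache that counts hits and misses (hypothetical class)."""

    def __init__(self, max_entries=128):
        self.max_entries = max_entries
        self._entries = OrderedDict()
        self.hits = 0
        self.misses = 0

    def get_or_compute(self, prompt, encode):
        """Return the cached embedding, calling encode(prompt) on a miss."""
        if prompt in self._entries:
            self.hits += 1
            self._entries.move_to_end(prompt)  # mark as most recently used
            return self._entries[prompt]
        self.misses += 1
        value = encode(prompt)
        self._entries[prompt] = value
        if len(self._entries) > self.max_entries:
            self._entries.popitem(last=False)  # evict least recently used
        return value

    def clear(self):
        """Evict every cached entry; counters are kept in this sketch."""
        self._entries.clear()

    def stats(self):
        return {"hits": self.hits, "misses": self.misses, "size": len(self._entries)}


def fake_encode(prompt):
    # Stand-in for an expensive text-encoder call.
    return [float(len(prompt))]


cache = PromptEmbedCache(max_entries=2)
cache.get_or_compute("a cat", fake_encode)  # miss: computed and stored
cache.get_or_compute("a cat", fake_encode)  # hit: served from the cache
stats_before = cache.stats()
cache.clear()
stats_after = cache.stats()
print(stats_before, stats_after)
```

Whether the real cache resets its counters on eviction is not stated in this reference; the sketch keeps them so that the effect of `clear()` on the entry count alone is visible.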
load_model
load_model(
    memory_pool_context_fn: callable | None = None,
    load_format: str | None = None,
    custom_pipeline_name: str | type | None = None,
) -> None
Load the diffusion model, apply compilation and offloading.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| memory_pool_context_fn | callable \| None | Optional function that returns a context manager for memory pool allocation (used for sleep mode). | None |
| load_format | str \| None | Format for loading model weights. Supported formats: "default" (default): automatically detect and use the default format based on configuration; "custom_pipeline": initialize the model from a custom pipeline class specified by | None |
| custom_pipeline_name | str \| type \| None | Optional custom pipeline class name to use. | None |
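Because `custom_pipeline_name` accepts either a string or a class, a loader has to normalize it before use. The sketch below shows one way this could work; it is not taken from the vLLM-Omni source. The helper names `resolve_pipeline_class` and `select_loader` are hypothetical, and treating the string form as a dotted import path is an assumption.

```python
import collections
import importlib


def resolve_pipeline_class(custom_pipeline_name):
    """Turn a `str | type | None` argument into a class, or None.

    Hypothetical helper; assumes strings are dotted paths like "pkg.mod.Class".
    """
    if custom_pipeline_name is None:
        return None
    if isinstance(custom_pipeline_name, type):
        return custom_pipeline_name  # already a class, use it directly
    module_path, _, class_name = custom_pipeline_name.rpartition(".")
    if not module_path:
        raise ValueError(f"expected a dotted path, got {custom_pipeline_name!r}")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


def select_loader(load_format=None, custom_pipeline_name=None):
    """Map the two documented load formats to a (format, class) pair."""
    fmt = load_format or "default"  # None falls back to "default"
    if fmt == "default":
        return ("default", None)
    if fmt == "custom_pipeline":
        cls = resolve_pipeline_class(custom_pipeline_name)
        if cls is None:
            raise ValueError("custom_pipeline requires custom_pipeline_name")
        return ("custom_pipeline", cls)
    raise ValueError(f"unsupported load_format: {fmt!r}")


# Both the string and the class form resolve to the same class object.
# collections.OrderedDict is only a stand-in for a real pipeline class.
from_string = select_loader("custom_pipeline", "collections.OrderedDict")
from_class = select_loader("custom_pipeline", collections.OrderedDict)
print(from_string == from_class, select_loader())
```

Rejecting unknown formats with an explicit error keeps a typo in `load_format` from silently falling back to the default path.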
load_weights
Load weights into the pipeline.