vllm_omni.diffusion.worker.diffusion_model_runner

Diffusion Model Runner for vLLM-Omni.

Handles model loading, compilation, caching, and execution of diffusion model forward passes. This follows the autoregressive (AR) runner pattern, in which the Runner handles all model-related operations.

logger module-attribute

logger = init_logger(__name__)

DiffusionModelRunner

Bases: OmniConnectorModelRunnerMixin

Model runner that handles model loading and execution for diffusion models.

This class follows the AR pattern where the Runner handles all model-related operations including loading, compilation, offloading, caching, and execution. The Worker only handles infrastructure (device, distributed env).

cache_backend instance-attribute

cache_backend = None

device instance-attribute

device = device

kv_transfer_manager instance-attribute

kv_transfer_manager = from_od_config(od_config)

od_config instance-attribute

od_config = od_config

offload_backend instance-attribute

offload_backend = None

pipeline instance-attribute

pipeline = None

prompt_embed_cache instance-attribute

prompt_embed_cache = None

state_cache instance-attribute

state_cache: dict[str, DiffusionRequestState] = {}

vllm_config instance-attribute

vllm_config = vllm_config

clear_prompt_embed_cache

clear_prompt_embed_cache() -> None

Evict all cached text-encoder outputs (e.g. between training epochs).

execute_model

execute_model(req: OmniDiffusionRequest) -> DiffusionOutput

Execute a forward pass for the given request.

Parameters:

Name Type Description Default
req OmniDiffusionRequest

A diffusion request containing a list of prompts to process.

required

Returns:

Type Description
DiffusionOutput

DiffusionOutput with generated results.

Note

We use torch.no_grad() for HSDP because HSDP2's fully_shard requires access to tensor version counters in pre_forward hooks, which inference tensors do not track. For non-HSDP inference, we use torch.inference_mode() for better performance.
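The choice described in the note can be sketched as follows. This is a minimal illustration rather than the runner's actual code: `use_hsdp`, `inference_context`, and `run_forward` are hypothetical names, and a small `torch.nn.Linear` stands in for the diffusion pipeline.

```python
import torch


def inference_context(use_hsdp: bool):
    """Pick the gradient-disabling context (hypothetical helper).

    use_hsdp is an assumed flag; the real runner derives this from its config.
    """
    # no_grad() yields ordinary tensors, which keep version counters.
    # inference_mode() yields inference tensors, which do not track them.
    return torch.no_grad() if use_hsdp else torch.inference_mode()


def run_forward(model: torch.nn.Module, x: torch.Tensor, use_hsdp: bool) -> torch.Tensor:
    with inference_context(use_hsdp):
        return model(x)


model = torch.nn.Linear(4, 2)  # stand-in for the diffusion pipeline
x = torch.randn(1, 4)

out_hsdp = run_forward(model, x, use_hsdp=True)
out_plain = run_forward(model, x, use_hsdp=False)

print(out_hsdp.is_inference(), out_plain.is_inference())
```

Both outputs are detached from autograd; only the one produced under `torch.inference_mode()` is an inference tensor.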

execute_stepwise

execute_stepwise(
    scheduler_output: DiffusionSchedulerOutput,
) -> BatchRunnerOutput

Execute one step for one scheduled request and return runner output.

get_prompt_embed_cache_stats

get_prompt_embed_cache_stats() -> dict | None

Return hit/miss statistics for the prompt-embedding cache, if enabled.
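To make the `dict | None` return type concrete, here is a small sketch of a prompt-embedding cache with hit/miss counters. Everything in it is an assumption for illustration: `PromptEmbedCacheSketch`, `get_or_encode`, `get_stats`, the LRU eviction policy, and the keys of the stats dict are hypothetical and may not match the real cache.

```python
from collections import OrderedDict
from typing import Callable, Optional


class PromptEmbedCacheSketch:
    """Hypothetical LRU cache keyed by prompt text; not the vLLM-Omni class."""

    def __init__(self, max_entries: int = 128) -> None:
        self.max_entries = max_entries
        self.hits = 0
        self.misses = 0
        self._entries = OrderedDict()

    def get_or_encode(self, prompt: str, encode: Callable[[str], list]) -> list:
        if prompt in self._entries:
            self.hits += 1
            self._entries.move_to_end(prompt)  # mark as most recently used
            return self._entries[prompt]
        self.misses += 1
        embedding = encode(prompt)  # the expensive text-encoder call
        self._entries[prompt] = embedding
        if len(self._entries) > self.max_entries:
            self._entries.popitem(last=False)  # evict least recently used
        return embedding

    def clear(self) -> None:
        # Assumption: clearing drops entries but keeps the counters.
        self._entries.clear()

    def stats(self) -> dict:
        return {"hits": self.hits, "misses": self.misses, "size": len(self._entries)}


def get_stats(cache: Optional[PromptEmbedCacheSketch]) -> Optional[dict]:
    """Mirror the 'statistics if enabled, else None' contract."""
    return cache.stats() if cache is not None else None


def fake_encode(prompt: str) -> list:
    return [float(len(prompt))]  # stand-in for a text encoder


cache = PromptEmbedCacheSketch(max_entries=2)
for prompt in ["a cat", "a cat", "a dog"]:
    cache.get_or_encode(prompt, fake_encode)

stats_before = get_stats(cache)
cache.clear()
stats_after = get_stats(cache)
stats_disabled = get_stats(None)
```

A caller should check for `None` before reading the statistics, since the cache may be disabled.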

load_model

load_model(
    memory_pool_context_fn: callable | None = None,
    load_format: str | None = None,
    custom_pipeline_name: str | type | None = None,
) -> None

Load the diffusion model, apply compilation and offloading.

Parameters:

Name Type Description Default
memory_pool_context_fn callable | None

Optional function that returns a context manager for memory pool allocation (used for sleep mode).

None
load_format str | None

Format for loading model weights. Supported formats:

- "default" (default): automatically detect and use the default format based on configuration.
- "custom_pipeline": initialize the model from a custom pipeline class specified by custom_pipeline_name.
- "dummy": skip actual weight loading; useful for testing and for custom pipelines that don't require default weights.

None
custom_pipeline_name str | type | None

Optional custom pipeline class, or the name of one, to use.

None
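The three documented load_format values can be read as a small dispatch. The sketch below only restates the parameter descriptions above; `LoadPlan` and `plan_load` are hypothetical names, and treating `None` as "default", rejecting unknown values, and requiring a pipeline name for "custom_pipeline" are assumptions, not documented behavior.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class LoadPlan:
    """Hypothetical summary of what a load would do."""
    pipeline_source: str
    skip_weights: bool


def plan_load(
    load_format: Optional[str] = None,
    custom_pipeline_name: Union[str, type, None] = None,
) -> LoadPlan:
    fmt = load_format or "default"  # assumption: None behaves like "default"
    if fmt == "default":
        return LoadPlan("auto-detected from configuration", skip_weights=False)
    if fmt == "custom_pipeline":
        if custom_pipeline_name is None:
            # Assumption: a pipeline must be named in this mode.
            raise ValueError("custom_pipeline requires custom_pipeline_name")
        # The parameter accepts either a class or its name.
        name = (
            custom_pipeline_name
            if isinstance(custom_pipeline_name, str)
            else custom_pipeline_name.__name__
        )
        return LoadPlan(name, skip_weights=False)
    if fmt == "dummy":
        return LoadPlan("auto-detected from configuration", skip_weights=True)
    raise ValueError(f"unsupported load_format: {fmt!r}")


class MyPipeline:  # hypothetical custom pipeline class
    pass


default_plan = plan_load()
custom_plan = plan_load("custom_pipeline", MyPipeline)
dummy_plan = plan_load("dummy")
```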

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights into the pipeline.
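The signature takes a stream of (name, tensor) pairs and returns a set of strings. A minimal sketch of that shape follows, assuming the returned set holds the names of the parameters that were actually loaded. The documentation above does not state this, so treat it as an assumption. `load_weights_sketch`, the parameter names, and the use of plain lists in place of tensors are all hypothetical.

```python
from typing import Iterable, Iterator


def load_weights_sketch(params: dict, weights: Iterable[tuple]) -> set:
    """Copy matching weights into params and report which names were loaded."""
    loaded = set()
    for name, value in weights:
        if name not in params:
            continue  # assumption: names the pipeline does not own are skipped
        params[name] = value
        loaded.add(name)
    return loaded


def weight_stream() -> Iterator[tuple]:
    # A generator keeps only one weight in memory at a time.
    yield "transformer.proj.weight", [0.1, 0.2]
    yield "vae.decoder.bias", [0.0]
    yield "unused.extra", [1.0]


params = {
    "transformer.proj.weight": None,
    "vae.decoder.bias": None,
    "text_encoder.embed": None,
}

loaded = load_weights_sketch(params, weight_stream())
missing = set(params) - loaded  # parameters the stream never supplied
```

Under this reading, comparing the returned set against the expected parameter names is a cheap way to detect an incomplete checkpoint.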

supports_step_mode

supports_step_mode() -> bool

Return whether the current pipeline supports step execution.