
vllm_omni.diffusion.worker

Worker classes for diffusion models.

Modules:

diffusion_model_runner: Diffusion Model Runner for vLLM-Omni.

diffusion_worker: Diffusion Worker for vLLM-Omni.

input_batch: Diffusion input-batch structures following the MRV2-style vLLM layout.

utils: Per-request mutable state for step-wise diffusion execution.

DiffusionModelRunner

Bases: OmniConnectorModelRunnerMixin

Model runner that handles model loading and execution for diffusion models.

This class follows the autoregressive (AR) runner pattern: the Runner handles all model-related operations, including loading, compilation, offloading, caching, and execution, while the Worker handles only infrastructure (device setup and the distributed environment).
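The Runner/Worker split can be illustrated with a minimal, hypothetical sketch. `ToyRunner` and `ToyWorker` are illustrative names, not vLLM-Omni classes, and the "pipeline" is a placeholder lambda; only the shape of the delegation (the worker owns device selection, the runner owns the pipeline) mirrors the description above.

```python
from typing import Callable, Optional


class ToyRunner:
    """Hypothetical runner: owns the pipeline and all model-related work."""

    def __init__(self, device: str) -> None:
        self.device = device
        self.pipeline: Optional[Callable[[list[str]], list[str]]] = None

    def load_model(self) -> None:
        # Placeholder pipeline; a real runner would load weights and apply
        # compilation, offloading, and caching here.
        self.pipeline = lambda prompts: [f"image<{p}>@{self.device}" for p in prompts]

    def execute_model(self, prompts: list[str]) -> list[str]:
        if self.pipeline is None:
            raise RuntimeError("load_model() must be called first")
        return self.pipeline(prompts)


class ToyWorker:
    """Hypothetical worker: owns device setup and delegates model work."""

    def __init__(self, rank: int, runner_cls: type = ToyRunner) -> None:
        self.rank = rank
        self.device = f"cuda:{rank}"  # infrastructure concern: device selection
        # The runner class is injectable, echoing model_runner_cls(...) in DiffusionWorker.
        self.model_runner = runner_cls(device=self.device)

    def load_model(self) -> None:
        self.model_runner.load_model()

    def execute_model(self, prompts: list[str]) -> list[str]:
        return self.model_runner.execute_model(prompts)


worker = ToyWorker(rank=0)
worker.load_model()
print(worker.execute_model(["a cat"]))  # ['image<a cat>@cuda:0']
```

Keeping the runner class injectable lets the worker stay unchanged when the model-side logic is swapped out.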

cache_backend instance-attribute

cache_backend = None

device instance-attribute

device = device

kv_transfer_manager instance-attribute

kv_transfer_manager = OmniKVTransferManager.from_od_config(
    od_config
)

od_config instance-attribute

od_config = od_config

offload_backend instance-attribute

offload_backend = None

pipeline instance-attribute

pipeline = None

prompt_embed_cache instance-attribute

prompt_embed_cache = None

state_cache instance-attribute

state_cache: dict[str, DiffusionRequestState] = {}

vllm_config instance-attribute

vllm_config = vllm_config

clear_prompt_embed_cache

clear_prompt_embed_cache() -> None

Evict all cached text-encoder outputs (e.g. between training epochs).

execute_model

execute_model(req: OmniDiffusionRequest) -> DiffusionOutput

Execute a forward pass for the given request.

Parameters:

req (OmniDiffusionRequest, required): A diffusion request containing a list of prompts to process.

Returns:

DiffusionOutput: A DiffusionOutput with the generated results.

Note

We use torch.no_grad() for HSDP because HSDP2's fully_shard requires access to tensor version counters in pre_forward hooks, which inference tensors do not track. For non-HSDP inference, we use torch.inference_mode() for better performance.

execute_stepwise

execute_stepwise(
    scheduler_output: DiffusionSchedulerOutput,
) -> BatchRunnerOutput

Execute one step for one scheduled request and return runner output.
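Step-wise execution implies per-request state that survives between calls, which is what the `state_cache: dict[str, DiffusionRequestState]` attribute suggests. The following is a minimal, hypothetical sketch of that idea: `ToyRequestState`, `ToyStepRunner`, and the halving "denoising" update are illustrative stand-ins, and plain arguments replace the real `DiffusionSchedulerOutput`.

```python
from dataclasses import dataclass, field


@dataclass
class ToyRequestState:
    """Hypothetical per-request mutable state for step-wise execution."""
    request_id: str
    num_steps: int
    step_index: int = 0
    latents: list[float] = field(default_factory=lambda: [1.0])


class ToyStepRunner:
    def __init__(self) -> None:
        # Keyed by request id, like the state_cache attribute.
        self.state_cache: dict[str, ToyRequestState] = {}

    def execute_stepwise(self, request_id: str, num_steps: int) -> bool:
        """Run one step for one request; return True once it has finished."""
        state = self.state_cache.get(request_id)
        if state is None:
            state = ToyRequestState(request_id, num_steps)
            self.state_cache[request_id] = state
        # Placeholder "denoising" update: halve each latent value.
        state.latents = [x * 0.5 for x in state.latents]
        state.step_index += 1
        finished = state.step_index >= state.num_steps
        if finished:
            # Drop the state once the final step has run.
            del self.state_cache[request_id]
        return finished


runner = ToyStepRunner()
steps = 1
while not runner.execute_stepwise("req-0", num_steps=3):
    steps += 1
print(steps, runner.state_cache)  # 3 {}
```

Because each call advances a single request by a single step, a scheduler could interleave steps from different requests between calls.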

get_prompt_embed_cache_stats

get_prompt_embed_cache_stats() -> dict | None

Return hit/miss statistics for the prompt-embedding cache, if enabled.
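A prompt-embedding cache memoizes text-encoder outputs keyed by prompt. The sketch below is a hypothetical illustration of how such a cache could track hits and misses and be cleared, in the spirit of `clear_prompt_embed_cache()` and `get_prompt_embed_cache_stats()`; `ToyPromptEmbedCache`, `get_cache_stats`, and the stats keys are assumptions, and the "encoder" is a toy function.

```python
from typing import Callable, Optional


class ToyPromptEmbedCache:
    """Hypothetical memo cache for text-encoder outputs with hit/miss counters."""

    def __init__(self, encode: Callable[[str], list[float]]) -> None:
        self._encode = encode
        self._store: dict[str, list[float]] = {}
        self.hits = 0
        self.misses = 0

    def get(self, prompt: str) -> list[float]:
        if prompt in self._store:
            self.hits += 1
        else:
            self.misses += 1
            self._store[prompt] = self._encode(prompt)  # run the encoder once
        return self._store[prompt]

    def clear(self) -> None:
        # Evict every cached embedding; counters are left untouched.
        self._store.clear()

    def stats(self) -> dict:
        return {"hits": self.hits, "misses": self.misses, "size": len(self._store)}


def get_cache_stats(cache: Optional[ToyPromptEmbedCache]) -> Optional[dict]:
    # Mirrors the `dict | None` return type: None when caching is disabled.
    return None if cache is None else cache.stats()


cache = ToyPromptEmbedCache(lambda p: [float(len(p))])
cache.get("a cat")
cache.get("a cat")
print(get_cache_stats(cache))  # {'hits': 1, 'misses': 1, 'size': 1}
```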

load_model

load_model(
    memory_pool_context_fn: callable | None = None,
    load_format: str = "default",
    custom_pipeline_name: str | None = None,
) -> None

Load the diffusion model, apply compilation and offloading.

Parameters:

memory_pool_context_fn (callable | None, default None): Optional function that returns a context manager for memory pool allocation (used for sleep mode).

load_format (str, default 'default'): Format for loading model weights. Supported formats:
- "default" (default): automatically detect and use the default format based on configuration.
- "custom_pipeline": initialize the model from the custom pipeline class specified by custom_pipeline_name.
- "dummy": skip actual weight loading; useful for testing and for custom pipelines that don't require default weights.

custom_pipeline_name (str | None, default None): Optional custom pipeline class name to use.
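How the three `load_model` parameters interact can be sketched as a small dispatch. This is a hypothetical illustration, not the actual loader: `plan_model_load` returns a description instead of building a pipeline, and the rule that `"custom_pipeline"` requires `custom_pipeline_name` is an assumption. `contextlib.nullcontext` stands in when no memory-pool function is supplied.

```python
import contextlib
from typing import Callable, Optional

# Formats listed in the load_format description.
SUPPORTED_LOAD_FORMATS = ("default", "custom_pipeline", "dummy")


def plan_model_load(
    memory_pool_context_fn: Optional[Callable[[], contextlib.AbstractContextManager]] = None,
    load_format: str = "default",
    custom_pipeline_name: Optional[str] = None,
) -> str:
    """Hypothetical dispatch over load_format; returns a description of the plan."""
    if load_format not in SUPPORTED_LOAD_FORMATS:
        raise ValueError(f"unsupported load_format: {load_format!r}")
    # Assumption: a custom pipeline cannot be built without its class name.
    if load_format == "custom_pipeline" and custom_pipeline_name is None:
        raise ValueError("load_format='custom_pipeline' requires custom_pipeline_name")

    # Allocate inside the caller's memory pool (used for sleep mode),
    # or inside a no-op context when no pool function is given.
    ctx = memory_pool_context_fn() if memory_pool_context_fn else contextlib.nullcontext()
    with ctx:
        if load_format == "custom_pipeline":
            return f"init {custom_pipeline_name}"
        if load_format == "dummy":
            return "init pipeline, skip weight loading"
        return "auto-detect format, load weights"


print(plan_model_load(load_format="dummy"))  # init pipeline, skip weight loading
```

Wrapping construction in the context manager means every allocation made while building the pipeline lands in the tagged pool that sleep mode can later release.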

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights into the pipeline.

supports_step_mode

supports_step_mode() -> bool

Return whether the current pipeline supports step-wise execution.

DiffusionWorker

A worker that manages GPU infrastructure and delegates to the model runner.

This class handles infrastructure initialization only:
- Device setup (CUDA device selection)
- Distributed environment (NCCL, model parallel)
- Memory management (sleep/wake)

All model-related operations (loading, compilation, execution) are delegated to DiffusionModelRunner.

device instance-attribute

device: device | None = None

local_rank instance-attribute

local_rank = local_rank

lora_manager instance-attribute

lora_manager: DiffusionLoRAManager | None = None

model_runner instance-attribute

model_runner: DiffusionModelRunner | None = (
    model_runner_cls(
        vllm_config=self.vllm_config,
        od_config=self.od_config,
        device=self.device,
    )
)

od_config instance-attribute

od_config = od_config

profiler instance-attribute

profiler: WorkerProfiler | None = self._create_profiler()

rank instance-attribute

rank = rank

stage_id instance-attribute

stage_id = getattr(od_config, 'stage_id', 0)

vllm_config instance-attribute

vllm_config: VllmConfig | None = None

add_lora

add_lora(lora_request: LoRARequest) -> bool

execute_model

execute_model(
    req: OmniDiffusionRequest,
    od_config: OmniDiffusionConfig,
) -> DiffusionOutput

Execute a forward pass by delegating to the model runner.

execute_stepwise

execute_stepwise(
    scheduler_output: DiffusionSchedulerOutput,
) -> BaseRunnerOutput

Execute one diffusion step by delegating to the model runner.

generate

generate(request: OmniDiffusionRequest) -> DiffusionOutput

Generate output for the given request.

handle_sleep_task

handle_sleep_task(task: OmniSleepTask) -> OmniACK

handle_wake_task

handle_wake_task(task: OmniWakeTask) -> OmniACK

init_device

init_device() -> None

Initialize the device and distributed environment.

init_lora_manager

init_lora_manager() -> None

Initialize the LoRA manager for this worker.

list_loras

list_loras() -> list[int]

load_model

load_model(
    load_format: str = "default",
    custom_pipeline_name: str | None = None,
    **kwargs,
) -> None

Load the diffusion model using DiffusionModelRunner.

load_weights

load_weights(weights) -> set[str]

Load weights by delegating to the model runner.

pin_lora

pin_lora(adapter_id: int) -> bool

profile

profile(
    is_start: bool = True, profile_prefix: str | None = None
) -> None

Start or stop profiling for this GPU worker.

Parameters:

is_start (bool, default True): True to start profiling, False to stop.

profile_prefix (str | None, default None): Optional prefix for the trace filename.
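The start/stop toggle with an optional filename prefix can be sketched with the standard-library `cProfile` module. This is a hypothetical stand-in for the worker's `WorkerProfiler`, which likely records GPU traces rather than `cProfile` stats; `ToyProfiledWorker`, the `_rank{N}.prof` naming, and taking the prefix at start time are all assumptions.

```python
import cProfile
import os
import tempfile
from typing import Optional


class ToyProfiledWorker:
    """Hypothetical start/stop toggle shaped like profile(is_start, profile_prefix)."""

    def __init__(self, rank: int, trace_dir: str) -> None:
        self.rank = rank
        self.trace_dir = trace_dir
        self._profiler: Optional[cProfile.Profile] = None
        self._prefix: Optional[str] = None

    def profile(self, is_start: bool = True, profile_prefix: Optional[str] = None) -> None:
        if is_start:
            if self._profiler is not None:
                raise RuntimeError("profiling already started")
            self._prefix = profile_prefix  # assumption: prefix is taken at start
            self._profiler = cProfile.Profile()
            self._profiler.enable()
            return
        if self._profiler is None:
            raise RuntimeError("profiling was not started")
        self._profiler.disable()
        # One trace file per rank so workers do not overwrite each other.
        name = f"{self._prefix or 'trace'}_rank{self.rank}.prof"
        self._profiler.dump_stats(os.path.join(self.trace_dir, name))
        self._profiler = None


trace_dir = tempfile.mkdtemp()
w = ToyProfiledWorker(rank=0, trace_dir=trace_dir)
w.profile(is_start=True, profile_prefix="warmup")
sum(i * i for i in range(10_000))  # the work being profiled
w.profile(is_start=False)
print(sorted(os.listdir(trace_dir)))  # ['warmup_rank0.prof']
```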

remove_lora

remove_lora(adapter_id: int) -> bool
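The LoRA methods above (`add_lora`, `list_loras`, `pin_lora`, `remove_lora`) carry no descriptions here. The sketch below is a hypothetical registry that matches their return types, assuming a fixed number of adapter slots, oldest-first eviction, and that pinning exempts an adapter from eviction. It takes a plain `adapter_id` instead of a `LoRARequest` and is not `DiffusionLoRAManager`'s implementation.

```python
class ToyLoRARegistry:
    """Hypothetical adapter bookkeeping with pinning and oldest-first eviction."""

    def __init__(self, max_loras: int) -> None:
        self.max_loras = max_loras
        self._order: list[int] = []  # insertion order, oldest first
        self._pinned: set[int] = set()

    def add_lora(self, adapter_id: int) -> bool:
        if adapter_id in self._order:
            return False
        if len(self._order) >= self.max_loras:
            # Evict the oldest adapter that is not pinned.
            victim = next((a for a in self._order if a not in self._pinned), None)
            if victim is None:
                return False  # every slot is pinned; refuse the new adapter
            self._order.remove(victim)
        self._order.append(adapter_id)
        return True

    def list_loras(self) -> list[int]:
        return list(self._order)

    def pin_lora(self, adapter_id: int) -> bool:
        if adapter_id not in self._order:
            return False
        self._pinned.add(adapter_id)
        return True

    def remove_lora(self, adapter_id: int) -> bool:
        if adapter_id not in self._order:
            return False
        self._order.remove(adapter_id)
        self._pinned.discard(adapter_id)
        return True


reg = ToyLoRARegistry(max_loras=2)
reg.add_lora(1)
reg.add_lora(2)
reg.pin_lora(1)
reg.add_lora(3)  # evicts 2, the oldest unpinned adapter
print(reg.list_loras())  # [1, 3]
```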

shutdown

shutdown() -> None

Shut down the worker and clean up the distributed environment.

sleep

sleep(level: int = 1) -> bool

Put the worker to sleep, offloading model weights.

Parameters:

level (int, default 1): Sleep level. Level 1 offloads weights; level 2 also saves buffers.

wake_up

wake_up(tags: list[str] | None = None) -> bool

Wake up the worker from sleep mode.

Re-activates the memory allocator for the specified tags and restores model buffers from CPU back to GPU if they were saved during Level 2 sleep.

Parameters:

tags (list[str] | None, default None): List of memory pool tags to re-activate (e.g., ["weights"] to match Level 1 sleep). If None, all pools are re-activated.
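The interplay of sleep levels and wake-up tags can be modelled as plain bookkeeping. This hypothetical sketch tracks only which tagged memory pools are active and whether level-2 sleep left buffers on the CPU; it moves no memory. `ToySleepModel`, the `buffer_restores` counter, and the `"workspace"` tag in the usage are illustrative assumptions.

```python
from typing import Optional


class ToySleepModel:
    """Hypothetical bookkeeping for sleep(level) and wake_up(tags)."""

    def __init__(self, tags: tuple = ("weights",)) -> None:
        self.active = {tag: True for tag in tags}  # tag -> memory pool usable?
        self.buffers_on_cpu = False  # set by level-2 sleep
        self.buffer_restores = 0

    def sleep(self, level: int = 1) -> bool:
        if level not in (1, 2):
            raise ValueError("sleep level must be 1 or 2")
        # Level 2 additionally saves model buffers to CPU before release.
        self.buffers_on_cpu = level == 2
        for tag in self.active:
            self.active[tag] = False  # offload weights / release the pool
        return True

    def wake_up(self, tags: Optional[list] = None) -> bool:
        targets = list(self.active) if tags is None else tags  # None = all pools
        for tag in targets:
            if tag not in self.active:
                raise KeyError(f"unknown memory pool tag: {tag!r}")
            self.active[tag] = True
        if self.buffers_on_cpu:
            # Stand-in for copying the saved buffers back to the device.
            self.buffer_restores += 1
            self.buffers_on_cpu = False
        return True


m = ToySleepModel(("weights", "workspace"))
m.sleep(level=2)
m.wake_up(["weights"])
print(m.active, m.buffer_restores)  # {'weights': True, 'workspace': False} 1
```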

WorkerProc

Wrapper that runs one Worker in a separate process.

context instance-attribute

context = zmq.Context(io_threads=2)

gpu_id instance-attribute

gpu_id = gpu_id

mq instance-attribute

mq = MessageQueue.create_from_handle(
    broadcast_handle, gpu_id
)

od_config instance-attribute

od_config = od_config

result_mq instance-attribute

result_mq = None

result_mq_handle instance-attribute

result_mq_handle = None

wake_event instance-attribute

wake_event = wake_event

worker instance-attribute

worker = self._create_worker(
    gpu_id,
    od_config,
    worker_extension_cls,
    custom_pipeline_args,
)

execute_rpc

execute_rpc(
    rpc_request: dict,
) -> tuple[object | None, bool]

Execute an RPC request and indicate whether to reply.
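The `(result, should_reply)` return shape can be illustrated with a hypothetical dispatcher that looks up a method by name and lets only rank 0 reply, consistent with `return_result` replying only on rank 0. The request keys `"method"`, `"args"`, and `"kwargs"` are assumptions, not the actual message format.

```python
from typing import Any, Optional


def execute_rpc_sketch(target: Any, rpc_request: dict, rank: int) -> tuple[Optional[object], bool]:
    """Hypothetical RPC dispatch: call target.<method>(...) and decide whether to reply.

    Assumed request shape: {"method": str, "args": tuple, "kwargs": dict}.
    """
    method = getattr(target, rpc_request["method"])
    result = method(*rpc_request.get("args", ()), **rpc_request.get("kwargs", {}))
    # Every rank executes the call, but only rank 0 sends the result back.
    should_reply = rank == 0
    return result, should_reply


class Echo:
    def scale(self, x: float, factor: float = 1.0) -> float:
        return x * factor


req = {"method": "scale", "args": (2.0,), "kwargs": {"factor": 3.0}}
print(execute_rpc_sketch(Echo(), req, rank=0))  # (6.0, True)
```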

recv_message

recv_message()

Receive messages from the broadcast queue.

return_result

return_result(output: Any)

Reply to the client; only rank 0 replies.

worker_busy_loop

worker_busy_loop() -> None

Main busy loop for multiprocessing workers.
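The receive/execute/reply cycle of a busy loop can be sketched with `queue.Queue` objects standing in for the broadcast `MessageQueue` and the result queue. This is a hypothetical illustration: `worker_busy_loop_sketch`, the `SHUTDOWN` sentinel, and the handler table are assumptions, and exceptions are returned to the caller instead of terminating the loop.

```python
import queue
from typing import Any, Callable

SHUTDOWN = {"type": "shutdown"}  # hypothetical sentinel message


def worker_busy_loop_sketch(
    rank: int,
    inbox: "queue.Queue[dict]",
    outbox: "queue.Queue[Any]",
    handlers: dict[str, Callable[..., Any]],
) -> None:
    """Hypothetical loop: receive, execute, and reply on rank 0 until shutdown."""
    while True:
        msg = inbox.get()  # stands in for recv_message() on the broadcast queue
        if msg == SHUTDOWN:
            break
        try:
            output: Any = handlers[msg["method"]](*msg.get("args", ()))
        except Exception as exc:  # report failures instead of killing the loop
            output = exc
        if rank == 0:  # like return_result(): only rank 0 replies
            outbox.put(output)


inbox, outbox = queue.Queue(), queue.Queue()
inbox.put({"method": "double", "args": (21,)})
inbox.put(SHUTDOWN)
worker_busy_loop_sketch(0, inbox, outbox, {"double": lambda x: 2 * x})
print(outbox.get_nowait())  # 42
```

Every rank runs the same loop on the same broadcast messages, so collective operations stay in lockstep while the client receives a single reply.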

worker_main staticmethod

worker_main(
    rank: int,
    od_config: OmniDiffusionConfig,
    pipe_writer: Connection,
    broadcast_handle,
    wake_event: Event,
    worker_extension_cls: str | None = None,
    custom_pipeline_args: dict[str, Any] | None = None,
) -> None

Worker initialization and execution loops.