
vllm_omni.engine.stage_init_utils

Stage initialization helpers for vLLM-Omni multi-stage runtime.

Extracts orchestration-level init logic (config extraction, plugin loading, multiprocessing setup, device mapping, device locking, engine args building) out of StageEngineCoreClient into reusable functions.

logger module-attribute

logger = init_logger(__name__)

LogicalStageInitPlan dataclass

Startup plan for one logical stage.

configured_stage_id instance-attribute

configured_stage_id: int

replicas instance-attribute

replicas: list[ReplicaInitPlan]

stage_idx instance-attribute

stage_idx: int

ReplicaInitPlan dataclass

One concrete replica startup unit within a logical stage.

executor_class class-attribute instance-attribute

executor_class: type | None = None

launch_mode instance-attribute

launch_mode: str

metadata instance-attribute

metadata: Any

num_replicas instance-attribute

num_replicas: int

omni_kv_connector instance-attribute

omni_kv_connector: tuple[
    dict[str, Any] | None, str | None, str | None
]

replica_id instance-attribute

replica_id: int

stage_cfg instance-attribute

stage_cfg: Any

stage_connector_spec instance-attribute

stage_connector_spec: dict[str, Any]

stage_vllm_config class-attribute instance-attribute

stage_vllm_config: Any | None = None

StageMetadata dataclass

Lightweight stage attributes extracted from stage_config.

cfg_kv_collect_func class-attribute instance-attribute

cfg_kv_collect_func: Callable | None = None

custom_process_input_func instance-attribute

custom_process_input_func: Callable | None

default_sampling_params instance-attribute

default_sampling_params: OmniSamplingParams

engine_input_source instance-attribute

engine_input_source: list[int]

engine_output_type instance-attribute

engine_output_type: str | None

final_output instance-attribute

final_output: bool

final_output_type instance-attribute

final_output_type: str | None

is_comprehension instance-attribute

is_comprehension: bool

model_stage instance-attribute

model_stage: str | None

prompt_expand_func class-attribute instance-attribute

prompt_expand_func: Callable | None = None

replica_id class-attribute instance-attribute

replica_id: int = 0

requires_multimodal_data instance-attribute

requires_multimodal_data: bool

runtime_cfg instance-attribute

runtime_cfg: Any

stage_id instance-attribute

stage_id: int

stage_type instance-attribute

stage_type: Literal['llm', 'diffusion']

StageRemoteFactoryContext dataclass

Per-stage context cached by AsyncOmniEngine for dynamic replica attach.

Populated once during _bootstrap_orchestrator from the per-stage init plans. _build_remote_replica consumes it to construct the right head-side stage client when a headless replica registers.

base_metadata instance-attribute

base_metadata: Any

diffusion_batch_size class-attribute instance-attribute

diffusion_batch_size: int = 1

executor_class class-attribute instance-attribute

executor_class: type | None = None

stage_cfg instance-attribute

stage_cfg: Any

stage_id instance-attribute

stage_id: int

stage_type instance-attribute

stage_type: str

vllm_config class-attribute instance-attribute

vllm_config: Any | None = None

acquire_device_locks

acquire_device_locks(
    stage_id: int,
    engine_args_dict: dict[str, Any],
    stage_init_timeout: int,
) -> list[int]

Acquire exclusive file locks on devices needed by this stage.

Returns a list of lock file descriptors that must be released after init.
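
The acquire/release pairing can be illustrated with a minimal sketch. It assumes POSIX `fcntl.flock` locks on one lock file per device; the lock directory, the file naming, the polling loop, and the function names `acquire_locks_sketch` / `release_locks_sketch` are all hypothetical and may differ from what the module actually does.

```python
import fcntl
import os
import tempfile
import time


def release_locks_sketch(lock_fds):
    """Unlock and close every descriptor returned by acquire_locks_sketch."""
    for fd in lock_fds:
        try:
            fcntl.flock(fd, fcntl.LOCK_UN)
        finally:
            os.close(fd)


def acquire_locks_sketch(device_ids, timeout_s, lock_dir=None):
    """Take an exclusive lock per device, polling until timeout_s elapses.

    Hypothetical layout: one lock file per device id under lock_dir.
    """
    lock_dir = lock_dir or tempfile.gettempdir()
    deadline = time.monotonic() + timeout_s
    fds = []
    try:
        for dev in device_ids:
            path = os.path.join(lock_dir, f"sketch_device_{dev}.lock")
            fd = os.open(path, os.O_RDWR | os.O_CREAT, 0o644)
            fds.append(fd)
            while True:
                try:
                    # Non-blocking attempt so the timeout can be enforced.
                    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
                    break
                except BlockingIOError:
                    if time.monotonic() >= deadline:
                        raise TimeoutError(f"device {dev} is still locked")
                    time.sleep(0.05)
    except BaseException:
        # Do not leak descriptors or partial locks on failure.
        release_locks_sketch(fds)
        raise
    return fds
```

A caller would typically wrap stage initialization in `try` / `finally` so that the descriptors are released even when init fails.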

acquire_diffusion_device_locks

acquire_diffusion_device_locks(
    stage_id: int, od_config: Any, stage_init_timeout: int
) -> list[int]

Acquire init locks for the GPU set used by a diffusion stage.

Diffusion stages express their device count via OmniDiffusionConfig's parallel_config.world_size rather than the LLM-style tensor_parallel_size knob, so adapt to the shape that acquire_device_locks understands.
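
The adaptation described above might look roughly like the following sketch. The `tensor_parallel_size` dictionary key and the helper name `diffusion_lock_args_sketch` are assumptions made for illustration; only `parallel_config.world_size` comes from the description.

```python
from types import SimpleNamespace


def diffusion_lock_args_sketch(od_config):
    """Map a diffusion config onto an LLM-style engine-args dict.

    Assumption: the lock helper reads the device count from a
    'tensor_parallel_size' entry.
    """
    world_size = int(od_config.parallel_config.world_size)
    return {"tensor_parallel_size": world_size}


# Stand-in object with the same attribute shape as the described config.
od_config = SimpleNamespace(parallel_config=SimpleNamespace(world_size=4))
lock_args = diffusion_lock_args_sketch(od_config)
```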

apply_cli_tokenizer

apply_cli_tokenizer(
    engine_args: dict[str, Any],
    *,
    cli_tokenizer: str | None,
    stage_defines_tokenizer: bool,
) -> None

Forward CLI tokenizer unless the stage config defines its own.

build_diffusion_config

build_diffusion_config(
    model: str, stage_cfg: Any, metadata: StageMetadata
) -> Any

Build diffusion config for a stage.

build_engine_args_dict

build_engine_args_dict(
    stage_config: Any,
    model: str,
    stage_connector_spec: dict[str, Any] | None = None,
    cli_tokenizer: str | None = None,
) -> dict[str, Any]

Build the normalized engine args dict for one stage.

build_llm_stage_output_processor

build_llm_stage_output_processor(
    plan: LogicalStageInitPlan,
    stage_vllm_config: Any,
    log_stats: bool = False,
) -> Any | None

Build one output processor per logical LLM stage.

log_stats controls whether the processor populates per-request IterationStats (consumed by the Prometheus wrap). Default False matches the upstream MultimodalOutputProcessor default and respects the --log-stats CLI flag plumbed through AsyncOmniEngine.

build_stage0_input_processor

build_stage0_input_processor(
    stage_vllm_config: Any,
) -> InputProcessor

Build the shared stage-0 input processor.

build_vllm_config

build_vllm_config(
    stage_config: Any,
    model: str,
    stage_connector_spec: dict[str, Any] | None = None,
    engine_args_dict: dict[str, Any] | None = None,
    headless: bool = False,
) -> tuple[Any, type]

Build engine args, then create VllmConfig and executor_class.

Returns:

tuple[Any, type]: (vllm_config, executor_class)

capture_stage_factory_contexts

capture_stage_factory_contexts(
    stage_plans: Sequence[LogicalStageInitPlan],
    diffusion_batch_size: int,
) -> dict[int, StageRemoteFactoryContext]

Snapshot per-stage construction context for dynamic replica attach.

Called once after _initialize_stages finishes. The captured context holds everything _build_remote_replica needs to build a fresh head-side client when a new headless replica registers (vllm_config / executor_class for LLM, batch_size for diffusion, plus the base stage metadata).

Per-replica fields like replica_id are filled in at build time, not at capture time.

compute_replica_layout

compute_replica_layout(
    stage_configs: Sequence[Any],
    *,
    allow_zero: bool = False,
) -> tuple[list[int], dict[int, list[str]]]

Compute per-stage replica counts and device assignments.

Parameters:

stage_configs (Sequence[Any], required): per-stage config objects with a runtime sub-config exposing num_replicas and devices.

allow_zero (bool, default False): when True, num_replicas == 0 is honored (used by single-stage / head-distributed mode for non-self stages that will be filled dynamically by remote registrations); when False (default), the count is clamped to at least 1.

Returns:

replicas_per_stage (list[int]): num_replicas per logical stage.

replica_devices_map (dict[int, list[str]]): stage_idx -> per-replica device strings (only for stages with num_replicas > 1).
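
The effect of allow_zero on the replica counts can be shown with a small sketch. It covers only the clamping rule, not the device assignment; the helper name `replica_counts_sketch` and the fallback to 1 when num_replicas is missing are assumptions.

```python
from types import SimpleNamespace


def replica_counts_sketch(stage_configs, *, allow_zero=False):
    """Return one replica count per stage, clamped as documented."""
    floor = 0 if allow_zero else 1
    counts = []
    for cfg in stage_configs:
        runtime = getattr(cfg, "runtime", None)
        n = getattr(runtime, "num_replicas", None)
        if n is None:
            n = 1  # assumed default when the field is absent
        counts.append(max(int(n), floor))
    return counts


# Stand-in stage configs exposing runtime.num_replicas and runtime.devices.
stages = [
    SimpleNamespace(runtime=SimpleNamespace(num_replicas=0, devices=None)),
    SimpleNamespace(runtime=SimpleNamespace(num_replicas=3, devices="0,1,2")),
]
clamped = replica_counts_sketch(stages)
honored = replica_counts_sketch(stages, allow_zero=True)
```

With these inputs, the default call raises the first stage's count to 1, while `allow_zero=True` keeps it at 0 so the stage can be filled later by remote registrations.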

extract_stage_metadata

extract_stage_metadata(stage_config: Any) -> StageMetadata

Pure data extraction from a stage_config object.

get_stage_connector_spec

get_stage_connector_spec(
    omni_transfer_config: Any,
    stage_id: int,
    async_chunk: bool,
) -> dict[str, Any]

Return the first connector spec for the stage when async chunking is enabled.

get_stage_devices_per_replica

get_stage_devices_per_replica(stage_cfg: Any) -> int

Return the number of devices consumed by one replica of stage_cfg.

get_stage_tp_size

get_stage_tp_size(stage_cfg: Any) -> int

Extract tensor_parallel_size from a stage config object.

initialize_diffusion_stage

initialize_diffusion_stage(
    stage_id: int,
    model: str,
    stage_cfg: Any,
    metadata: StageMetadata,
    stage_init_timeout: int,
    batch_size: int = 1,
    use_inline: bool = False,
) -> Any

Build a diffusion stage client.

Parameters:

model (str, required): Model name or path.

stage_cfg (Any, required): Stage configuration.

metadata (StageMetadata, required): Extracted stage metadata.

stage_init_timeout (int, required): Timeout in seconds for the stage initialization handshake.

batch_size (int, default 1): Maximum number of requests to batch together in the diffusion engine. Passed through to StageDiffusionClient and ultimately to AsyncOmni.

use_inline (bool, default False): If True, uses the inline diffusion client instead of a subprocess.

inject_kv_stage_info

inject_kv_stage_info(
    stage_cfg: Any,
    stage_id: int,
    stage_configs: Sequence[Any] | None = None,
) -> None

Inject stage_id, engine_input_source, and inferred TP topology into omni_kv_config.

When stage_configs is provided, also infers from_tp/to_tp for heterogeneous TP topologies so the KV transfer manager can compute rank mappings automatically.

load_omni_transfer_config_for_model

load_omni_transfer_config_for_model(
    model: str, config_path: str | None
) -> Any

Load omni transfer config from an explicit path or resolved model config.

Resolves base_config inheritance (CI overlay → base deploy YAML) so that connectors defined in the base config are visible to the transfer config parser.

patch_generation_config_if_needed

patch_generation_config_if_needed(
    model_config: Any,
) -> None

Guard InputProcessor init for models whose config lacks model_type.

prepare_engine_environment

prepare_engine_environment() -> None

One-time global setup: load plugins, set multiprocessing spawn method.

release_device_locks

release_device_locks(lock_fds: list[int]) -> None

Release file locks acquired by acquire_device_locks.

resolve_worker_cls

resolve_worker_cls(engine_args: dict[str, Any]) -> None

Resolve worker_cls from worker_type for non-diffusion stages.

set_death_signal

set_death_signal(sig: int) -> None

Best-effort parent-death signal for Linux subprocesses.
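
On Linux, a parent-death signal is usually requested through `prctl(PR_SET_PDEATHSIG, ...)`. The sketch below shows that mechanism via `ctypes`; whether the module uses this exact route is an assumption, and the name `set_death_signal_sketch` and its boolean return value are illustrative only.

```python
import ctypes
import signal
import sys

PR_SET_PDEATHSIG = 1  # option value from <linux/prctl.h>


def set_death_signal_sketch(sig: int) -> bool:
    """Ask the kernel to send `sig` to this process when its parent exits.

    Best effort: returns False instead of raising when unsupported.
    """
    if not sys.platform.startswith("linux"):
        return False
    try:
        libc = ctypes.CDLL(None, use_errno=True)
        return libc.prctl(PR_SET_PDEATHSIG, int(sig), 0, 0, 0) == 0
    except (OSError, AttributeError):
        return False
```

A subprocess would call something like `set_death_signal_sketch(signal.SIGTERM)` early in its entry point, so that it does not outlive a crashed parent.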

setup_stage_devices

setup_stage_devices(
    stage_id: int, runtime_cfg: Any
) -> None

Device mapping via set_stage_devices for a single stage.

split_devices_for_replicas

split_devices_for_replicas(
    devices_str: str | None,
    num_replicas: int,
    tp_size: int,
    stage_id: int,
) -> list[str]

Split a devices string into per-replica subsets.

When num_replicas is 1, returns [devices_str] unchanged. Otherwise, two YAML shapes are accepted:

  1. Legacy / pool mode, len(devices) == num_replicas * tp_size: the string enumerates the full per-stage pool. Each replica gets tp_size consecutive entries. The values are logical indices into the launcher's CUDA_VISIBLE_DEVICES.

     split_devices_for_replicas("1,2,3,4", 2, 2, 1) → ["1,2", "3,4"]

  2. Template mode, len(devices) == tp_size: the YAML declares a single per-replica template (the same shape one replica would use), and is dp-independent. Each replica r gets the offsets [r*tp_size + a for a in template] of the launcher's CUDA_VISIBLE_DEVICES. The template's entries must lie in [0, tp_size).

     split_devices_for_replicas("0,1", 2, 2, 1) → ["0,1", "2,3"]
     split_devices_for_replicas("0,1", 4, 2, 1) → ["0,1", "2,3", "4,5", "6,7"]

This lets the same devices: "0,1" YAML work for any --omni-dp-size-local: the launcher's CUDA_VISIBLE_DEVICES scales with the replica count, the YAML does not.

Any other length raises ValueError (the two modes are length-disjoint for num_replicas > 1).
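
The two accepted shapes can be reproduced with a short standalone sketch. It is written from the description above, not copied from the module, so the name `split_devices_sketch`, the parsing details, and the error messages are assumptions.

```python
def split_devices_sketch(devices_str, num_replicas, tp_size, stage_id):
    """Split a comma-separated devices string into per-replica subsets."""
    if num_replicas == 1:
        return [devices_str]

    entries = [d.strip() for d in (devices_str or "").split(",") if d.strip()]

    if len(entries) == num_replicas * tp_size:
        # Pool mode: hand out tp_size consecutive entries per replica.
        return [
            ",".join(entries[r * tp_size:(r + 1) * tp_size])
            for r in range(num_replicas)
        ]

    if len(entries) == tp_size:
        # Template mode: shift the template by r * tp_size for replica r.
        template = [int(e) for e in entries]
        if any(a < 0 or a >= tp_size for a in template):
            raise ValueError(
                f"stage {stage_id}: template entries must lie in [0, {tp_size})"
            )
        return [
            ",".join(str(r * tp_size + a) for a in template)
            for r in range(num_replicas)
        ]

    raise ValueError(
        f"stage {stage_id}: expected {num_replicas * tp_size} or {tp_size} "
        f"devices, got {len(entries)}"
    )


pool = split_devices_sketch("1,2,3,4", 2, 2, 1)   # ["1,2", "3,4"]
template = split_devices_sketch("0,1", 4, 2, 1)   # ["0,1", "2,3", "4,5", "6,7"]
```

Because the pool length and the template length can only coincide when num_replicas is 1, checking the lengths in either order gives the same result.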

stage_runtime_env

stage_runtime_env(
    stage_id: int, runtime_cfg: Any
) -> Generator[None, None, None]

Apply per-stage runtime.env for the duration of the context.
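
A context manager of this kind is commonly built with `contextlib.contextmanager`. The sketch below assumes that runtime_cfg exposes an `env` mapping of variable names to values; that attribute shape and the name `runtime_env_sketch` are assumptions for illustration.

```python
import os
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def runtime_env_sketch(runtime_cfg):
    """Set runtime_cfg.env in os.environ, then restore the prior values."""
    env = getattr(runtime_cfg, "env", None) or {}
    # Remember previous values (None means the variable was unset).
    saved = {key: os.environ.get(key) for key in env}
    os.environ.update({key: str(value) for key, value in env.items()})
    try:
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old


cfg = SimpleNamespace(env={"OMNI_SKETCH_FLAG": 1})
with runtime_env_sketch(cfg):
    inside = os.environ.get("OMNI_SKETCH_FLAG")
after = os.environ.get("OMNI_SKETCH_FLAG")
```

Restoring in a `finally` block keeps the process environment clean even if stage initialization raises inside the `with` body.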

stage_runtime_setup

stage_runtime_setup(
    stage_id: int, runtime_cfg: Any
) -> Generator[None, None, None]

Apply per-stage runtime.env and runtime.devices for the context.

Restores runtime.env on exit. Device visibility restore remains the caller's responsibility (e.g. AsyncOmniEngine saves/restores the platform device-control env var around this block).

terminate_alive_proc

terminate_alive_proc(proc, timeout=5)