Skip to content

vllm_gaudi.v1.engine.multi_model_async_llm

Multi-model support for AsyncLLM on Gaudi platform.

Simplified version that does not use complex mode/pause handling and focuses on core functionality: initialize -> generate -> switch -> generate.

logger module-attribute

logger = init_logger(__name__)

MultiModelAsyncLLM

Wrapper around AsyncLLM for dynamic model switching.

Usage flow: 1. Create with model configs: MultiModelAsyncLLM({"model_a": config_a, "model_b": config_b}) 2. Initialize with first model: await manager.initialize("model_a") 3. Generate: async for output in manager.generate(prompt, params, request_id): ... 4. Switch models: await manager.switch_model("model_b") 5. Generate with new model 6. Cleanup: manager.shutdown()

Example

from vllm.engine.arg_utils import AsyncEngineArgs from vllm_gaudi.v1.engine.multi_model_async_llm import MultiModelAsyncLLM

models = { ... "model_a": AsyncEngineArgs(model="meta-llama/Llama-3.1-8B-Instruct"), ... "model_b": AsyncEngineArgs(model="Qwen/Qwen3-0.6B"), ... } manager = MultiModelAsyncLLM(models) await manager.initialize("model_a") async for output in manager.generate("Hello", SamplingParams(max_tokens=20), "req-1"): ... print(output.outputs[0].text) await manager.switch_model("model_b") manager.shutdown()

Source code in vllm_gaudi/v1/engine/multi_model_async_llm.py
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
class MultiModelAsyncLLM:
    """
    Wrapper around AsyncLLM for dynamic model switching.

    Usage flow:
    1. Create with model configs: MultiModelAsyncLLM({"model_a": config_a, "model_b": config_b})
    2. Initialize with first model: await manager.initialize("model_a")
    3. Generate: async for output in manager.generate(prompt, params, request_id): ...
    4. Switch models: await manager.switch_model("model_b")
    5. Generate with new model
    6. Cleanup: manager.shutdown()

    Example:
        >>> from vllm.engine.arg_utils import AsyncEngineArgs
        >>> from vllm_gaudi.v1.engine.multi_model_async_llm import MultiModelAsyncLLM
        >>>
        >>> models = {
        ...     "model_a": AsyncEngineArgs(model="meta-llama/Llama-3.1-8B-Instruct"),
        ...     "model_b": AsyncEngineArgs(model="Qwen/Qwen3-0.6B"),
        ... }
        >>> manager = MultiModelAsyncLLM(models)
        >>> await manager.initialize("model_a")
        >>> async for output in manager.generate("Hello", SamplingParams(max_tokens=20), "req-1"):
        ...     print(output.outputs[0].text)
        >>> await manager.switch_model("model_b")
        >>> manager.shutdown()
    """

    def __init__(
        self,
        model_configs: dict[str, AsyncEngineArgs],
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        disable_log_stats: bool = False,
        enable_log_requests: bool = False,
        model_quant_configs: dict[str, str | None] | None = None,
    ):
        """
        Set up the manager and pre-build one VllmConfig per registered model.

        No engine is created here; call ``initialize()`` for that.

        Args:
            model_configs: Mapping of model name to AsyncEngineArgs (must be non-empty)
            usage_context: Usage context passed to config and engine creation
            disable_log_stats: Stored; written onto the engine args by ``initialize()``
            enable_log_requests: Stored; written onto the engine args by ``initialize()``
            model_quant_configs: Optional mapping of model name to its
                QUANT_CONFIG path (INC FP8 calibration JSON)

        Raises:
            ValueError: If ``model_configs`` is empty
        """
        install_engine_core_patch()

        # Private state is assigned before the validation below; shutdown()
        # reads these attributes.
        self._switching_lock = asyncio.Lock()
        self._vllm_configs: dict[str, VllmConfig] = {}
        self._current_model_name: str | None = None
        self._sleeping: dict[str, bool] = {}
        self._engine: AsyncLLM | None = None

        if not model_configs:
            raise ValueError("model_configs cannot be empty")

        self.model_quant_configs: dict[str, str | None] = model_quant_configs if model_quant_configs else {}
        self.enable_log_requests = enable_log_requests
        self.disable_log_stats = disable_log_stats
        self.usage_context = usage_context
        self.model_configs = model_configs

        # Build every model's VllmConfig eagerly, in registration order.
        logger.info("Creating configs for %s models", len(model_configs))
        for name, engine_args in model_configs.items():
            built_config = engine_args.create_engine_config(usage_context)
            self._vllm_configs[name] = built_config
            logger.info("  %s: %s", name, built_config.model_config.model)

    def _apply_quant_config_env(self, model_name: str) -> None:
        """Sync the process-level ``QUANT_CONFIG`` variable with *model_name*.

        Call it before workers are spawned during ``initialize()``, so child
        processes inherit the correct quantization calibration file for the
        selected model. Model switches update worker state through the
        reconfiguration path.

        Outcome per ``model_quant_configs`` entry:
        - no entry for the model: the environment is left untouched
        - entry is ``None``: ``QUANT_CONFIG`` is removed from the environment
        - entry is a path: ``QUANT_CONFIG`` is set to that path
        """
        if model_name not in self.model_quant_configs:
            logger.info("[quant_config] QUANT_CONFIG preserved from environment (model=%s)", model_name)
            return

        calibration_path = self.model_quant_configs[model_name]
        if calibration_path is None:
            os.environ.pop("QUANT_CONFIG", None)
            logger.info("[quant_config] QUANT_CONFIG unset (model=%s)", model_name)
            return

        os.environ["QUANT_CONFIG"] = calibration_path
        logger.info("[quant_config] QUANT_CONFIG=%s (model=%s)", calibration_path, model_name)

    @property
    def current_model(self) -> str | None:
        """Name of the model currently loaded, or ``None`` when no model is loaded."""
        return self._current_model_name

    @property
    def available_models(self) -> list[str]:
        """Names of all registered models, as a new list in registration order."""
        return list(self.model_configs)

    def get_vllm_config(self, model_name: str) -> VllmConfig:
        """Look up the pre-built VllmConfig for *model_name*.

        Raises:
            ValueError: If no config was built for *model_name*
        """
        if model_name in self._vllm_configs:
            return self._vllm_configs[model_name]
        raise ValueError(f"Model '{model_name}' not found. Available: {list(self.model_configs.keys())}")

    def get_all_vllm_configs(self) -> dict[str, VllmConfig]:
        """
        Return every pre-built VllmConfig, e.g. for model registry building.

        The returned dict is a new (shallow) copy, so adding or removing keys
        does not affect the manager; the VllmConfig values themselves are the
        same objects the manager holds.

        Returns:
            Dictionary mapping model names to their VllmConfig objects
        """
        return dict(self._vllm_configs)

    @property
    def engine(self) -> AsyncLLM:
        """The underlying AsyncLLM engine.

        Raises:
            RuntimeError: If ``initialize()`` has not created an engine yet
        """
        active_engine = self._engine
        if active_engine is not None:
            return active_engine
        raise RuntimeError("Engine not initialized. Call initialize() first.")

    async def _refresh_engine_frontend_config(self, model_name: str) -> None:
        """Point the AsyncLLM frontend at *model_name*'s configuration.

        Engine core reloads model weights/config in-place, but the AsyncLLM
        frontend keeps its own ``model_config``, renderer, and processors used
        by API request validation/tokenization. This method:

        1. cancels the running ``output_handler`` task (if any) and waits for it,
        2. swaps the config objects, shuts down the old renderer (best effort)
           and builds a new renderer, io/input/output processors,
        3. restarts the output handler so it uses the new
           ``output_processor`` / ``renderer``.

        Args:
            model_name: Key into ``self._vllm_configs``. It is not validated
                here; an unknown name raises ``KeyError``.

        Raises:
            RuntimeError: If the engine has not been initialized
        """
        if self._engine is None:
            raise RuntimeError("Engine not initialized. Call initialize() first.")

        target_config = self._vllm_configs[model_name]
        engine = self._engine

        # --- 1. Cancel the old output_handler before rebuilding processors ---
        old_task = getattr(engine, "output_handler", None)
        if old_task is not None and not old_task.done():
            old_task.cancel()
            # NOTE(review): this suppress also swallows a cancellation aimed at
            # the *calling* task while it waits here -- confirm that is acceptable.
            with contextlib.suppress(asyncio.CancelledError):
                await old_task
        engine.output_handler = None

        # --- 2. Rebuild config / renderer / processors ---
        engine.vllm_config = target_config
        engine.model_config = target_config.model_config
        engine.observability_config = target_config.observability_config

        old_renderer = getattr(engine, "renderer", None)
        if old_renderer:
            # Best effort: a failing shutdown of the outgoing renderer must not
            # abort the switch, but it is logged instead of silently dropped.
            try:
                old_renderer.shutdown()
            except Exception:
                logger.warning(
                    "Failed to shut down previous renderer while switching to %s; continuing.",
                    model_name,
                    exc_info=True,
                )

        renderer = renderer_from_config(target_config)
        engine.renderer = renderer
        engine.io_processor = get_io_processor(
            target_config,
            renderer,
            target_config.model_config.io_processor_plugin,
        )
        engine.input_processor = InputProcessor(target_config, renderer)
        engine.output_processor = OutputProcessor(
            renderer.tokenizer,
            log_stats=engine.log_stats,
            stream_interval=target_config.scheduler_config.stream_interval,
            tracing_enabled=target_config.observability_config.otlp_traces_endpoint is not None,
        )

        # --- 3. Restart the output handler with the new processors ---
        engine._run_output_handler()

    async def initialize(self, model_name: str) -> None:
        """
        Create the underlying AsyncLLM engine with *model_name* loaded.

        The stored ``disable_log_stats`` / ``enable_log_requests`` flags are
        written onto the model's AsyncEngineArgs object before the engine is
        built from it.

        Args:
            model_name: Model to load (must be in model_configs)

        Raises:
            ValueError: If model_name is not a registered model
            RuntimeError: If an engine already exists
        """
        registered = self.model_configs
        if model_name not in registered:
            raise ValueError(f"Model '{model_name}' not found. Available: {list(self.model_configs.keys())}")
        if self._engine is not None:
            raise RuntimeError("Engine already initialized. Use switch_model() instead.")

        logger.info("Initializing engine with: %s", model_name)

        # QUANT_CONFIG is put into the environment before the engine is created.
        self._apply_quant_config_env(model_name)

        engine_args = registered[model_name]
        engine_args.disable_log_stats = self.disable_log_stats
        engine_args.enable_log_requests = self.enable_log_requests

        self._engine = AsyncLLM.from_engine_args(
            engine_args,
            start_engine_loop=True,
            usage_context=self.usage_context,
        )
        self._current_model_name = model_name
        self._sleeping[model_name] = False

        loaded_model = self._vllm_configs[model_name].model_config.model
        logger.info("Engine initialized with: %s", loaded_model)

    async def switch_model(
        self,
        model_name: str,
        drain_timeout: int = 60,
    ) -> dict[str, float | bool | None]:
        """
        Switch the running engine to another registered model, with error recovery.

        Calls are serialized through ``self._switching_lock``. Requesting the
        model that is already current is a no-op that returns immediately.

        Steps performed by this method:
        1. Drain pending requests: ``wait_for_requests_to_drain(drain_timeout)``
           wrapped in an outer ``asyncio.wait_for`` of ``drain_timeout + 5``
           seconds. A timeout is logged as a warning and the switch continues.
        2. Reconfigure the engine core: the target VllmConfig is serialized
           with cloudpickle and sent to the ``gaudi_reconfigure_engine``
           utility, together with the model's QUANT_CONFIG path when
           ``model_quant_configs`` has an entry for that model.
        3. Refresh the AsyncLLM frontend (``_refresh_engine_frontend_config``).
        4. Update the ``_sleeping`` / ``_current_model_name`` bookkeeping.

        NOTE(review): unloading the old weights, loading the new ones and
        rebuilding the KV cache are presumably done inside
        ``gaudi_reconfigure_engine``; that code is not part of this method.

        If anything in steps 1-4 raises, recovery is attempted with
        ``wake_up(tags=["weights", "kv_cache"])`` followed by
        ``resume_generation()``. Failures of either call are only logged, and a
        RuntimeError chained to the original exception is raised in every case.

        Args:
            model_name: Target model name
            drain_timeout: Seconds to wait for requests to drain

        Returns:
            Dict with ``"switched"`` (bool) and the timings ``"drain_s"``,
            ``"reconfigure_s"`` and ``"switch_s"`` in seconds. When the
            reconfigure utility returns a dict, its entries are merged in
            (and override same-named keys). For the no-op case ``"switched"``
            is False and all timings are 0.0.

        Raises:
            ValueError: If model not found
            RuntimeError: If engine not initialized or switch fails

        """
        async with self._switching_lock:
            switch_start = time.perf_counter()
            drain_s = 0.0
            reconfigure_s = 0.0

            # Precondition checks happen outside the try block below, so these
            # errors reach the caller unchanged and trigger no recovery.
            if self._engine is None:
                raise RuntimeError("Engine not initialized. Call initialize() first.")

            if model_name not in self.model_configs:
                raise ValueError(f"Model '{model_name}' not found. Available: {list(self.model_configs.keys())}")

            if model_name == self._current_model_name:
                logger.info("Model '%s' already loaded.", model_name)
                return {
                    "switched": False,
                    "drain_s": 0.0,
                    "reconfigure_s": 0.0,
                    "switch_s": 0.0,
                }

            new_model = self._vllm_configs[model_name].model_config.model

            logger.info("Switching from %s to %s", self._current_model_name, model_name)

            try:
                # Step 1: Drain pending requests
                logger.info("Draining pending requests...")
                drain_start = time.perf_counter()
                try:
                    # The outer wait_for adds 5 seconds on top of the drain
                    # call's own timeout argument.
                    await asyncio.wait_for(
                        self._engine.wait_for_requests_to_drain(drain_timeout),
                        timeout=drain_timeout + 5,
                    )
                except asyncio.TimeoutError:
                    logger.warning(
                        "Drain timeout (%ss) exceeded; in-flight requests will be aborted "
                        "by the reconfigure step (pause_scheduler mode='abort'). "
                        "Clients whose requests are aborted will receive errors.",
                        drain_timeout,
                    )
                finally:
                    # Recorded whether the drain finished, timed out or raised.
                    drain_s = time.perf_counter() - drain_start

                # Step 2: Reconfigure engine core and scheduler in-process
                logger.info("Reconfiguring engine for: %s", model_name)
                serialized_config = cloudpickle.dumps(self._vllm_configs[model_name])
                reconfigure_start = time.perf_counter()
                # The quant config path (which may be None) is passed only when
                # the model has an explicit entry; otherwise the utility is
                # called with the config alone.
                if model_name in self.model_quant_configs:
                    quant_config_path = self.model_quant_configs[model_name]
                    reconfigure_result = await self._engine.engine_core.call_utility_async(
                        "gaudi_reconfigure_engine",
                        serialized_config,
                        quant_config_path,
                    )
                else:
                    reconfigure_result = await self._engine.engine_core.call_utility_async(
                        "gaudi_reconfigure_engine",
                        serialized_config,
                    )
                reconfigure_s = time.perf_counter() - reconfigure_start
                logger.info(
                    "[gaudi_reconfigure] caller complete: to=%s elapsed=%.2fs",
                    model_name,
                    reconfigure_s,
                )
                previous_model_name = self._current_model_name
                assert previous_model_name is not None
                # Step 3: align the frontend (renderer/processors) with the new model.
                await self._refresh_engine_frontend_config(model_name)
                # Step 4: bookkeeping is updated only after the frontend refresh succeeded.
                self._sleeping[previous_model_name] = True
                self._sleeping[model_name] = False
                logger.info("Model sleep state: %s=sleeping", previous_model_name)
                logger.info("Model sleep state: %s=awake", model_name)
                self._current_model_name = model_name
                logger.info("Successfully switched from %s to: %s", previous_model_name, new_model)

                result: dict[str, float | bool | None] = {
                    "switched": True,
                    "drain_s": drain_s,
                    "reconfigure_s": reconfigure_s,
                    "switch_s": time.perf_counter() - switch_start,
                }
                # Entries returned by the utility are merged last, so they
                # override same-named timing keys.
                if isinstance(reconfigure_result, dict):
                    result.update(reconfigure_result)
                return result

            except Exception as e:
                logger.error("Model switch failed during %s: %s. Attempting to restore engine state...",
                             e.__class__.__name__, e)
                # Attempt recovery: wake up weights/KV cache if stuck in sleep, then
                # resume the scheduler (which may have been paused by gaudi_reconfigure_engine).
                try:
                    logger.info("Attempting to wake up engine for recovery...")
                    await self._engine.wake_up(tags=["weights", "kv_cache"])
                    if self._current_model_name is not None:
                        self._sleeping[self._current_model_name] = False
                        logger.info("Model sleep state: %s=awake", self._current_model_name)
                except Exception as recovery_error:
                    logger.error("Recovery wake_up failed: %s: %s", recovery_error.__class__.__name__, recovery_error)
                # Always attempt to resume the scheduler to avoid a permanently paused state.
                try:
                    await self._engine.resume_generation()
                    logger.warning("Engine recovered (wake_up + resume_generation). "
                                   "State may still be inconsistent — manual restart recommended "
                                   "if subsequent requests fail.")
                except Exception as resume_error:
                    logger.error(
                        "Recovery resume_generation failed: %s: %s. "
                        "Engine scheduler may be permanently paused. Manual server restart required.",
                        resume_error.__class__.__name__,
                        resume_error,
                    )

                # Re-raise as RuntimeError, chained to the original exception,
                # regardless of whether the recovery calls succeeded.
                raise RuntimeError(
                    f"Failed to switch model from {self._current_model_name} to {model_name}: {e}") from e

    async def generate(
        self,
        prompt: PromptType | EngineInput,
        sampling_params: SamplingParams,
        request_id: str,
        **kwargs,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Stream generation outputs for *prompt* from the underlying engine.

        This is an async generator, so the initialization check runs when
        iteration starts, not when ``generate()`` is called.

        Args:
            prompt: Input prompt
            sampling_params: Sampling parameters
            request_id: Unique request ID
            **kwargs: Additional args passed to AsyncLLM.generate()

        Yields:
            RequestOutput: Each output produced by AsyncLLM.generate()

        Raises:
            RuntimeError: If engine not initialized
        """
        engine = self._engine
        if engine is None:
            raise RuntimeError("Engine not initialized.")

        output_stream = engine.generate(prompt, sampling_params, request_id, **kwargs)
        async for request_output in output_stream:
            yield request_output

    async def encode(
        self,
        prompt: PromptType | EngineInput,
        pooling_params: PoolingParams,
        request_id: str,
        **kwargs,
    ) -> AsyncGenerator[PoolingRequestOutput, None]:
        """
        Stream encoding outputs for embedding/pooling models.

        This is an async generator, so the initialization check runs when
        iteration starts, not when ``encode()`` is called.

        Args:
            prompt: Input prompt
            pooling_params: Pooling parameters
            request_id: Unique request ID
            **kwargs: Additional args passed to AsyncLLM.encode()

        Yields:
            PoolingRequestOutput: Each output produced by AsyncLLM.encode()

        Raises:
            RuntimeError: If engine not initialized
        """
        engine = self._engine
        if engine is None:
            raise RuntimeError("Engine not initialized.")

        output_stream = engine.encode(prompt, pooling_params, request_id, **kwargs)
        async for pooling_output in output_stream:
            yield pooling_output

    async def abort(self, request_id: str | list[str]) -> None:
        """Forward an abort for one or more request IDs to the engine.

        Does nothing when no engine has been initialized.
        """
        engine = self._engine
        if engine is None:
            return
        await engine.abort(request_id)

    def shutdown(self):
        """Shut down the engine (if one exists) and reset the manager's state.

        Safe to call repeatedly: without an engine only the bookkeeping
        (``_sleeping`` and the current model name) is reset.
        """
        engine = self._engine
        if engine is not None:
            logger.info("Shutting down multi-model engine")
            engine.shutdown()
            # Dropped only after shutdown() returned without raising.
            self._engine = None
        self._current_model_name = None
        self._sleeping.clear()

    def __del__(self):
        """Shut the engine down when the manager is garbage-collected.

        The finalizer also runs for instances whose ``__init__`` raised before
        ``_engine`` was assigned (e.g. if installing the engine-core patch
        fails). ``shutdown()`` would then fail with AttributeError, so the
        attribute is probed defensively and shutdown is skipped when there is
        no engine to release.
        """
        if getattr(self, "_engine", None) is not None:
            self.shutdown()

    async def __aenter__(self):
        """Enter the async context; returns this manager (no engine is started)."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the async context by calling ``shutdown()``.

        Returns ``None``, so an exception raised inside the ``async with``
        body is not suppressed.
        """
        self.shutdown()

_current_model_name instance-attribute

_current_model_name: str | None = None

_engine instance-attribute

_engine: AsyncLLM | None = None

_sleeping instance-attribute

_sleeping: dict[str, bool] = {}

_switching_lock instance-attribute

_switching_lock = Lock()

_vllm_configs instance-attribute

_vllm_configs: dict[str, VllmConfig] = {}

available_models property

available_models: list[str]

Return list of available model names.

current_model property

current_model: str | None

Return currently loaded model name.

disable_log_stats instance-attribute

disable_log_stats = disable_log_stats

enable_log_requests instance-attribute

enable_log_requests = enable_log_requests

engine property

engine: AsyncLLM

Return underlying AsyncLLM engine.

model_configs instance-attribute

model_configs = model_configs

model_quant_configs instance-attribute

model_quant_configs: dict[str, str | None] = (
    model_quant_configs or {}
)

usage_context instance-attribute

usage_context = usage_context

__aenter__ async

__aenter__()

Enter the async context manager; returns this manager instance.

Source code in vllm_gaudi/v1/engine/multi_model_async_llm.py
async def __aenter__(self):
    """Enter the async context; returns this manager (no engine is started)."""
    return self

__aexit__ async

__aexit__(exc_type, exc_val, exc_tb)

Exit the async context manager by calling shutdown(); exceptions from the body are not suppressed.

Source code in vllm_gaudi/v1/engine/multi_model_async_llm.py
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Exit the async context by calling ``shutdown()``.

    Returns ``None``, so an exception raised inside the ``async with``
    body is not suppressed.
    """
    self.shutdown()

__del__

__del__()

Cleanup on deletion.

Source code in vllm_gaudi/v1/engine/multi_model_async_llm.py
def __del__(self):
    """Shut the engine down when the manager is garbage-collected.

    The finalizer also runs for instances whose ``__init__`` raised before
    ``_engine`` was assigned (e.g. if installing the engine-core patch
    fails). ``shutdown()`` would then fail with AttributeError, so the
    attribute is probed defensively and shutdown is skipped when there is
    no engine to release.
    """
    if getattr(self, "_engine", None) is not None:
        self.shutdown()

__init__

__init__(
    model_configs: dict[str, AsyncEngineArgs],
    usage_context: UsageContext = ENGINE_CONTEXT,
    disable_log_stats: bool = False,
    enable_log_requests: bool = False,
    model_quant_configs: dict[str, str | None]
    | None = None,
)

Initialize multi-model manager.

Parameters:

Name Type Description Default
model_configs dict[str, AsyncEngineArgs]

Dict mapping model names to AsyncEngineArgs

required
usage_context UsageContext

Engine usage context

ENGINE_CONTEXT
disable_log_stats bool

Disable stats logging

False
enable_log_requests bool

Enable request logging

False
model_quant_configs dict[str, str | None] | None

Optional dict mapping model names to their QUANT_CONFIG path (INC FP8 calibration JSON)

None
Source code in vllm_gaudi/v1/engine/multi_model_async_llm.py
def __init__(
    self,
    model_configs: dict[str, AsyncEngineArgs],
    usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    disable_log_stats: bool = False,
    enable_log_requests: bool = False,
    model_quant_configs: dict[str, str | None] | None = None,
):
    """
    Set up the manager and pre-build one VllmConfig per registered model.

    No engine is created here; call ``initialize()`` for that.

    Args:
        model_configs: Mapping of model name to AsyncEngineArgs (must be non-empty)
        usage_context: Usage context passed to config and engine creation
        disable_log_stats: Stored; written onto the engine args by ``initialize()``
        enable_log_requests: Stored; written onto the engine args by ``initialize()``
        model_quant_configs: Optional mapping of model name to its
            QUANT_CONFIG path (INC FP8 calibration JSON)

    Raises:
        ValueError: If ``model_configs`` is empty
    """
    install_engine_core_patch()

    # Private state is assigned before the validation below; shutdown()
    # reads these attributes.
    self._switching_lock = asyncio.Lock()
    self._vllm_configs: dict[str, VllmConfig] = {}
    self._current_model_name: str | None = None
    self._sleeping: dict[str, bool] = {}
    self._engine: AsyncLLM | None = None

    if not model_configs:
        raise ValueError("model_configs cannot be empty")

    self.model_quant_configs: dict[str, str | None] = model_quant_configs if model_quant_configs else {}
    self.enable_log_requests = enable_log_requests
    self.disable_log_stats = disable_log_stats
    self.usage_context = usage_context
    self.model_configs = model_configs

    # Build every model's VllmConfig eagerly, in registration order.
    logger.info("Creating configs for %s models", len(model_configs))
    for name, engine_args in model_configs.items():
        built_config = engine_args.create_engine_config(usage_context)
        self._vllm_configs[name] = built_config
        logger.info("  %s: %s", name, built_config.model_config.model)

_apply_quant_config_env

_apply_quant_config_env(model_name: str) -> None

Set or unset QUANT_CONFIG in the current process for model_name.

Call it before workers are spawned during initialize(), so child processes inherit the correct quantization calibration file for the selected model. Model switches update worker state through the reconfiguration path.

Source code in vllm_gaudi/v1/engine/multi_model_async_llm.py
def _apply_quant_config_env(self, model_name: str) -> None:
    """Sync the process-level ``QUANT_CONFIG`` variable with *model_name*.

    Call it before workers are spawned during ``initialize()``, so child
    processes inherit the correct quantization calibration file for the
    selected model. Model switches update worker state through the
    reconfiguration path.

    Outcome per ``model_quant_configs`` entry:
    - no entry for the model: the environment is left untouched
    - entry is ``None``: ``QUANT_CONFIG`` is removed from the environment
    - entry is a path: ``QUANT_CONFIG`` is set to that path
    """
    if model_name not in self.model_quant_configs:
        logger.info("[quant_config] QUANT_CONFIG preserved from environment (model=%s)", model_name)
        return

    calibration_path = self.model_quant_configs[model_name]
    if calibration_path is None:
        os.environ.pop("QUANT_CONFIG", None)
        logger.info("[quant_config] QUANT_CONFIG unset (model=%s)", model_name)
        return

    os.environ["QUANT_CONFIG"] = calibration_path
    logger.info("[quant_config] QUANT_CONFIG=%s (model=%s)", calibration_path, model_name)

_refresh_engine_frontend_config async

_refresh_engine_frontend_config(model_name: str) -> None

Refresh AsyncLLM frontend state to target model config.

Engine core reloads model weights/config in-place, but AsyncLLM frontend keeps its own model_config, renderer, and processors used by API request validation/tokenization. Keep these aligned with the switched model, then restart the background output handler so it picks up the new output_processor / renderer.

Source code in vllm_gaudi/v1/engine/multi_model_async_llm.py
async def _refresh_engine_frontend_config(self, model_name: str) -> None:
    """Point the AsyncLLM frontend at *model_name*'s configuration.

    Engine core reloads model weights/config in-place, but the AsyncLLM
    frontend keeps its own ``model_config``, renderer, and processors used
    by API request validation/tokenization. This method:

    1. cancels the running ``output_handler`` task (if any) and waits for it,
    2. swaps the config objects, shuts down the old renderer (best effort)
       and builds a new renderer, io/input/output processors,
    3. restarts the output handler so it uses the new
       ``output_processor`` / ``renderer``.

    Args:
        model_name: Key into ``self._vllm_configs``. It is not validated
            here; an unknown name raises ``KeyError``.

    Raises:
        RuntimeError: If the engine has not been initialized
    """
    if self._engine is None:
        raise RuntimeError("Engine not initialized. Call initialize() first.")

    target_config = self._vllm_configs[model_name]
    engine = self._engine

    # --- 1. Cancel the old output_handler before rebuilding processors ---
    old_task = getattr(engine, "output_handler", None)
    if old_task is not None and not old_task.done():
        old_task.cancel()
        # NOTE(review): this suppress also swallows a cancellation aimed at
        # the *calling* task while it waits here -- confirm that is acceptable.
        with contextlib.suppress(asyncio.CancelledError):
            await old_task
    engine.output_handler = None

    # --- 2. Rebuild config / renderer / processors ---
    engine.vllm_config = target_config
    engine.model_config = target_config.model_config
    engine.observability_config = target_config.observability_config

    old_renderer = getattr(engine, "renderer", None)
    if old_renderer:
        # Best effort: a failing shutdown of the outgoing renderer must not
        # abort the switch, but it is logged instead of silently dropped.
        try:
            old_renderer.shutdown()
        except Exception:
            logger.warning(
                "Failed to shut down previous renderer while switching to %s; continuing.",
                model_name,
                exc_info=True,
            )

    renderer = renderer_from_config(target_config)
    engine.renderer = renderer
    engine.io_processor = get_io_processor(
        target_config,
        renderer,
        target_config.model_config.io_processor_plugin,
    )
    engine.input_processor = InputProcessor(target_config, renderer)
    engine.output_processor = OutputProcessor(
        renderer.tokenizer,
        log_stats=engine.log_stats,
        stream_interval=target_config.scheduler_config.stream_interval,
        tracing_enabled=target_config.observability_config.otlp_traces_endpoint is not None,
    )

    # --- 3. Restart the output handler with the new processors ---
    engine._run_output_handler()

abort async

abort(request_id: str | list[str]) -> None

Abort request(s).

Source code in vllm_gaudi/v1/engine/multi_model_async_llm.py
async def abort(self, request_id: str | list[str]) -> None:
    """Abort request(s)."""
    if self._engine is not None:
        await self._engine.abort(request_id)

encode async

encode(
    prompt: PromptType | EngineInput,
    pooling_params: PoolingParams,
    request_id: str,
    **kwargs,
) -> AsyncGenerator[PoolingRequestOutput, None]

Encode input for embedding/pooling models.

Parameters:

Name Type Description Default
prompt PromptType | EngineInput

Input prompt

required
pooling_params PoolingParams

Pooling parameters

required
request_id str

Unique request ID

required
**kwargs

Additional args passed to AsyncLLM.encode()

{}

Yields:

Name Type Description
PoolingRequestOutput AsyncGenerator[PoolingRequestOutput, None]

Encoding outputs

Raises:

Type Description
RuntimeError

If engine not initialized

Source code in vllm_gaudi/v1/engine/multi_model_async_llm.py
async def encode(
    self,
    prompt: PromptType | EngineInput,
    pooling_params: PoolingParams,
    request_id: str,
    **kwargs,
) -> AsyncGenerator[PoolingRequestOutput, None]:
    """
    Stream encoding outputs for embedding/pooling models.

    This is an async generator, so the initialization check runs when
    iteration starts, not when ``encode()`` is called.

    Args:
        prompt: Input prompt
        pooling_params: Pooling parameters
        request_id: Unique request ID
        **kwargs: Additional args passed to AsyncLLM.encode()

    Yields:
        PoolingRequestOutput: Each output produced by AsyncLLM.encode()

    Raises:
        RuntimeError: If engine not initialized
    """
    engine = self._engine
    if engine is None:
        raise RuntimeError("Engine not initialized.")

    output_stream = engine.encode(prompt, pooling_params, request_id, **kwargs)
    async for pooling_output in output_stream:
        yield pooling_output

generate async

generate(
    prompt: PromptType | EngineInput,
    sampling_params: SamplingParams,
    request_id: str,
    **kwargs,
) -> AsyncGenerator[RequestOutput, None]

Generate completion for prompt.

Parameters:

Name Type Description Default
prompt PromptType | EngineInput

Input prompt

required
sampling_params SamplingParams

Sampling parameters

required
request_id str

Unique request ID

required
**kwargs

Additional args passed to AsyncLLM.generate()

{}

Yields:

Name Type Description
RequestOutput AsyncGenerator[RequestOutput, None]

Generation outputs

Raises:

Type Description
RuntimeError

If engine not initialized

Source code in vllm_gaudi/v1/engine/multi_model_async_llm.py
async def generate(
    self,
    prompt: PromptType | EngineInput,
    sampling_params: SamplingParams,
    request_id: str,
    **kwargs,
) -> AsyncGenerator[RequestOutput, None]:
    """
    Stream generation outputs for *prompt* from the underlying engine.

    This is an async generator, so the initialization check runs when
    iteration starts, not when ``generate()`` is called.

    Args:
        prompt: Input prompt
        sampling_params: Sampling parameters
        request_id: Unique request ID
        **kwargs: Additional args passed to AsyncLLM.generate()

    Yields:
        RequestOutput: Each output produced by AsyncLLM.generate()

    Raises:
        RuntimeError: If engine not initialized
    """
    engine = self._engine
    if engine is None:
        raise RuntimeError("Engine not initialized.")

    output_stream = engine.generate(prompt, sampling_params, request_id, **kwargs)
    async for request_output in output_stream:
        yield request_output

get_all_vllm_configs

get_all_vllm_configs() -> dict[str, VllmConfig]

Get all vllm_configs for model registry building.

Returns a shallow copy to prevent external modification.

Returns:

Type Description
dict[str, VllmConfig]

Dictionary mapping model names to their VllmConfig objects

Source code in vllm_gaudi/v1/engine/multi_model_async_llm.py
def get_all_vllm_configs(self) -> dict[str, VllmConfig]:
    """
    Return every stored VllmConfig, keyed by model name.

    The result comes from calling .copy() on the internal mapping, so adding
    or removing keys on it does not affect the manager's own mapping.

    Returns:
        Dictionary mapping model names to their VllmConfig objects
    """
    configs_snapshot = self._vllm_configs.copy()
    return configs_snapshot

get_vllm_config

get_vllm_config(model_name: str) -> VllmConfig

Get VllmConfig for a model.

Source code in vllm_gaudi/v1/engine/multi_model_async_llm.py
def get_vllm_config(self, model_name: str) -> VllmConfig:
    """
    Get VllmConfig for a model.

    Args:
        model_name: Name of the model whose config is requested

    Returns:
        The VllmConfig stored under model_name

    Raises:
        ValueError: If no VllmConfig is stored under model_name
    """
    if model_name not in self._vllm_configs:
        # List the keys of the mapping that was actually searched, so the
        # message never advertises a name this lookup would still reject.
        raise ValueError(f"Model '{model_name}' not found. Available: {list(self._vllm_configs.keys())}")
    return self._vllm_configs[model_name]

initialize async

initialize(model_name: str) -> None

Initialize engine with a model.

Parameters:

Name Type Description Default
model_name str

Model to load (must be in model_configs)

required

Raises:

Type Description
ValueError

If model_name not found

RuntimeError

If already initialized

Source code in vllm_gaudi/v1/engine/multi_model_async_llm.py
async def initialize(self, model_name: str) -> None:
    """
    Create the engine with the given model as the first one loaded.

    Args:
        model_name: Key of the model to load (must be in model_configs)

    Raises:
        ValueError: If model_name not found
        RuntimeError: If already initialized
    """
    known_models = self.model_configs
    if model_name not in known_models:
        available = list(known_models.keys())
        raise ValueError(f"Model '{model_name}' not found. Available: {available}")
    if self._engine is not None:
        raise RuntimeError("Engine already initialized. Use switch_model() instead.")

    logger.info("Initializing engine with: %s", model_name)
    self._apply_quant_config_env(model_name)

    # Copy the manager-level logging switches onto the stored engine args.
    engine_args = self.model_configs[model_name]
    engine_args.disable_log_stats = self.disable_log_stats
    engine_args.enable_log_requests = self.enable_log_requests

    self._engine = AsyncLLM.from_engine_args(
        engine_args,
        start_engine_loop=True,
        usage_context=self.usage_context,
    )

    # Record the freshly loaded model as awake and current.
    self._sleeping[model_name] = False
    self._current_model_name = model_name

    loaded_model = self._vllm_configs[model_name].model_config.model
    logger.info("Engine initialized with: %s", loaded_model)

shutdown

shutdown()

Shutdown engine and cleanup.

Source code in vllm_gaudi/v1/engine/multi_model_async_llm.py
def shutdown(self):
    """Stop the engine if one is running, then reset model tracking state."""
    engine = self._engine
    if engine is not None:
        logger.info("Shutting down multi-model engine")
        engine.shutdown()
        self._engine = None

    # Tracking state is reset whether or not an engine existed.
    self._sleeping.clear()
    self._current_model_name = None
switch_model async

switch_model(
    model_name: str, drain_timeout: int = 60
) -> dict[str, float | bool | None]

Switch to a different model, with error recovery.

Steps: 1. Drain pending requests (with timeout) 2. Reconfigure the engine core for the target model via the gaudi_reconfigure_engine utility 3. Refresh the engine frontend config for the new model 4. Record the new current model and the per-model sleep state

If any step fails, attempts to wake up the engine and resume generation to restore state, then raises a RuntimeError.

Parameters:

Name Type Description Default
model_name str

Target model name

required
drain_timeout int

Seconds to wait for requests to drain

60

Raises:

Type Description
ValueError

If model not found

RuntimeError

If engine not initialized or switch fails

Source code in vllm_gaudi/v1/engine/multi_model_async_llm.py
async def switch_model(
    self,
    model_name: str,
    drain_timeout: int = 60,
) -> dict[str, float | bool | None]:
    """
    Switch the running engine to a different model, with best-effort recovery.

    The whole operation runs while holding self._switching_lock.

    Steps:
    1. Drain pending requests (with timeout); a drain timeout is logged and
       the switch continues
    2. Serialize the target VllmConfig and call the engine core utility
       "gaudi_reconfigure_engine", passing the model's quant config path
       when one is registered in model_quant_configs
    3. Refresh the engine frontend config for the new model
    4. Update the per-model sleep flags and the current model name

    If any of these steps fails, attempts wake_up() and then
    resume_generation() on the engine before raising.

    Args:
        model_name: Target model name
        drain_timeout: Seconds to wait for requests to drain

    Returns:
        Dict with "switched" (bool) plus the timings "drain_s",
        "reconfigure_s" and "switch_s" in seconds. If the requested model is
        already current, "switched" is False and all timings are 0.0. When
        the reconfigure utility returns a dict, its entries are merged into
        the result and override keys of the same name.

    Raises:
        ValueError: If model not found
        RuntimeError: If engine not initialized or switch fails

    """
    async with self._switching_lock:
        switch_start = time.perf_counter()
        drain_s = 0.0
        reconfigure_s = 0.0

        # Validation happens before the try block, so these errors propagate
        # unchanged and do not trigger the recovery path below.
        if self._engine is None:
            raise RuntimeError("Engine not initialized. Call initialize() first.")

        if model_name not in self.model_configs:
            raise ValueError(f"Model '{model_name}' not found. Available: {list(self.model_configs.keys())}")

        if model_name == self._current_model_name:
            logger.info("Model '%s' already loaded.", model_name)
            return {
                "switched": False,
                "drain_s": 0.0,
                "reconfigure_s": 0.0,
                "switch_s": 0.0,
            }

        new_model = self._vllm_configs[model_name].model_config.model

        logger.info("Switching from %s to %s", self._current_model_name, model_name)

        try:
            # Step 1: Drain pending requests
            logger.info("Draining pending requests...")
            drain_start = time.perf_counter()
            try:
                # The outer wait_for allows 5 seconds more than drain_timeout,
                # which is also handed to the engine's own drain call.
                await asyncio.wait_for(
                    self._engine.wait_for_requests_to_drain(drain_timeout),
                    timeout=drain_timeout + 5,
                )
            except asyncio.TimeoutError:
                # A drain timeout is not fatal: log it and carry on with the switch.
                logger.warning(
                    "Drain timeout (%ss) exceeded; in-flight requests will be aborted "
                    "by the reconfigure step (pause_scheduler mode='abort'). "
                    "Clients whose requests are aborted will receive errors.",
                    drain_timeout,
                )
            finally:
                # Record drain time on success, timeout, or any other error.
                drain_s = time.perf_counter() - drain_start

            # Step 2: Reconfigure engine core and scheduler in-process
            logger.info("Reconfiguring engine for: %s", model_name)
            # Serialization happens before the timer starts, so it is not
            # included in reconfigure_s.
            serialized_config = cloudpickle.dumps(self._vllm_configs[model_name])
            reconfigure_start = time.perf_counter()
            if model_name in self.model_quant_configs:
                # Pass the quant config path only when one is registered for this model.
                quant_config_path = self.model_quant_configs[model_name]
                reconfigure_result = await self._engine.engine_core.call_utility_async(
                    "gaudi_reconfigure_engine",
                    serialized_config,
                    quant_config_path,
                )
            else:
                reconfigure_result = await self._engine.engine_core.call_utility_async(
                    "gaudi_reconfigure_engine",
                    serialized_config,
                )
            reconfigure_s = time.perf_counter() - reconfigure_start
            logger.info(
                "[gaudi_reconfigure] caller complete: to=%s elapsed=%.2fs",
                model_name,
                reconfigure_s,
            )
            previous_model_name = self._current_model_name
            assert previous_model_name is not None
            # Bookkeeping below only runs once the frontend refresh has succeeded.
            await self._refresh_engine_frontend_config(model_name)
            self._sleeping[previous_model_name] = True
            self._sleeping[model_name] = False
            logger.info("Model sleep state: %s=sleeping", previous_model_name)
            logger.info("Model sleep state: %s=awake", model_name)
            self._current_model_name = model_name
            logger.info("Successfully switched from %s to: %s", previous_model_name, new_model)

            result: dict[str, float | bool | None] = {
                "switched": True,
                "drain_s": drain_s,
                "reconfigure_s": reconfigure_s,
                "switch_s": time.perf_counter() - switch_start,
            }
            # Merge any dict returned by the reconfigure utility; its keys
            # take precedence over the timing keys set above.
            if isinstance(reconfigure_result, dict):
                result.update(reconfigure_result)
            return result

        except Exception as e:
            logger.error("Model switch failed during %s: %s. Attempting to restore engine state...",
                         e.__class__.__name__, e)
            # Attempt recovery: wake up weights/KV cache if stuck in sleep, then
            # resume the scheduler (which may have been paused by gaudi_reconfigure_engine).
            try:
                logger.info("Attempting to wake up engine for recovery...")
                await self._engine.wake_up(tags=["weights", "kv_cache"])
                if self._current_model_name is not None:
                    self._sleeping[self._current_model_name] = False
                    logger.info("Model sleep state: %s=awake", self._current_model_name)
            except Exception as recovery_error:
                logger.error("Recovery wake_up failed: %s: %s", recovery_error.__class__.__name__, recovery_error)
            # Always attempt to resume the scheduler to avoid a permanently paused state.
            # NOTE(review): the "Engine recovered (wake_up + resume_generation)" warning
            # below is logged whenever resume_generation() succeeds, even if wake_up()
            # failed just above — confirm whether that wording is intended.
            try:
                await self._engine.resume_generation()
                logger.warning("Engine recovered (wake_up + resume_generation). "
                               "State may still be inconsistent — manual restart recommended "
                               "if subsequent requests fail.")
            except Exception as resume_error:
                logger.error(
                    "Recovery resume_generation failed: %s: %s. "
                    "Engine scheduler may be permanently paused. Manual server restart required.",
                    resume_error.__class__.__name__,
                    resume_error,
                )

            # Re-raise original exception with context
            raise RuntimeError(
                f"Failed to switch model from {self._current_model_name} to {model_name}: {e}") from e