AsyncLLMEngine#
- class vllm.AsyncLLMEngine(worker_use_ray: bool, engine_use_ray: bool, *args, log_requests: bool = True, start_engine_loop: bool = True, **kwargs)[source]#
An asynchronous wrapper for LLMEngine.
This class wraps LLMEngine to make it asynchronous. It uses asyncio to create a background loop that keeps processing incoming requests. The LLMEngine is kicked by the generate method when there are requests in the waiting queue, and the generate method yields the outputs from the LLMEngine to the caller. (A usage sketch follows the parameter list below.)
- Parameters:
worker_use_ray – Whether to use Ray for model workers. Required for distributed execution. Should be the same as parallel_config.worker_use_ray.
engine_use_ray – Whether to make LLMEngine a Ray actor. If so, the async frontend will be executed in a separate process from the model workers.
log_requests – Whether to log the requests.
start_engine_loop – If True, the background task to run the engine will be automatically started in the generate call.
*args – Arguments for LLMEngine.
**kwargs – Arguments for LLMEngine.
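A minimal end-to-end usage sketch (not part of the reference itself): the model name, prompt, and request id below are placeholders, and the background loop is started lazily by the first generate call since start_engine_loop defaults to True.
>>> # Sketch only: construct the engine and stream outputs for one request.
>>> import asyncio
>>> from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
>>>
>>> engine_args = AsyncEngineArgs(model="facebook/opt-125m")  # placeholder model
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
>>>
>>> async def main():
>>>     final_output = None
>>>     async for output in engine.generate("What is LLM?",
>>>                                         SamplingParams(temperature=0.0),
>>>                                         request_id="example-0"):
>>>         final_output = output
>>>     print(final_output.outputs[0].text)
>>>
>>> asyncio.run(main())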
- async abort(request_id: str) → None [source]#
Abort a request.
Abort a submitted request. If the request is finished or not found, this method will be a no-op.
- Parameters:
request_id – The unique id of the request.
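A short, hedged sketch of a common abort pattern in a streaming handler. Here client_disconnected() is a hypothetical stand-in for whatever disconnect check the surrounding web framework provides (entrypoints/api_server.py uses the request's is_disconnected()).
>>> # Sketch only: abort a streaming request when the client goes away.
>>> async def stream_results(engine, request_id, results_generator):
>>>     async for request_output in results_generator:
>>>         if await client_disconnected():  # hypothetical disconnect check
>>>             # abort() is a no-op if the request already finished.
>>>             await engine.abort(request_id)
>>>             return
>>>         yield request_output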
- async encode(inputs: str | TextPrompt | TokensPrompt | ExplicitEncoderDecoderPrompt, pooling_params: PoolingParams, request_id: str, lora_request: LoRARequest | None = None, trace_headers: Mapping[str, str] | None = None) → AsyncGenerator[EmbeddingRequestOutput, None] [source]#
Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the request into the waiting queue of the LLMEngine and streams the outputs from the LLMEngine to the caller.
- Parameters:
inputs – The inputs to the LLM. See PromptInputs for more details about the format of each input.
pooling_params – The pooling parameters of the request.
request_id – The unique id of the request.
lora_request – LoRA request to use for generation, if any.
trace_headers – OpenTelemetry trace headers.
- Yields:
The output EmbeddingRequestOutput objects from the LLMEngine for the request.
- Details:
If the engine is not running, start the background loop, which iteratively invokes engine_step() to process the waiting requests.
Add the request to the engine’s RequestTracker. On the next background loop, this request will be sent to the underlying engine. Also, a corresponding AsyncStream will be created.
Wait for the request outputs from AsyncStream and yield them.
Example
>>> # Please refer to entrypoints/api_server.py for
>>> # the complete example.
>>>
>>> # initialize the engine and the example input
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
>>> example_input = {
>>>     "input": "What is LLM?",
>>>     "request_id": 0,
>>> }
>>>
>>> # start the generation
>>> results_generator = engine.encode(
>>>     example_input["input"],
>>>     PoolingParams(),
>>>     example_input["request_id"])
>>>
>>> # get the results
>>> final_output = None
>>> async for request_output in results_generator:
>>>     # `request` is the incoming HTTP request object
>>>     # (e.g. a FastAPI Request, as in entrypoints/api_server.py).
>>>     if await request.is_disconnected():
>>>         # Abort the request if the client disconnects.
>>>         await engine.abort(example_input["request_id"])
>>>         # Return or raise an error
>>>         ...
>>>     final_output = request_output
>>>
>>> # Process and return the final output
>>> ...
- async engine_step(virtual_engine: int) → bool [source]#
Kick the engine to process the waiting requests.
Returns True if there are in-progress requests.
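This method is normally driven by the engine's background loop rather than called directly. The following is a rough, hedged sketch of what a manual driver could look like, assuming a single virtual engine at index 0 and an engine created with start_engine_loop=False; the real background loop waits on the request tracker instead of sleeping.
>>> # Rough sketch only; not the actual background-loop implementation.
>>> import asyncio
>>>
>>> async def run_engine_loop(engine: AsyncLLMEngine) -> None:
>>>     while True:
>>>         has_requests = await engine.engine_step(0)  # virtual engine 0
>>>         if not has_requests:
>>>             # Nothing in progress; back off briefly before polling again.
>>>             await asyncio.sleep(0.01)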
- classmethod from_engine_args(engine_args: AsyncEngineArgs, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Dict[str, StatLoggerBase] | None = None) → AsyncLLMEngine [source]#
Creates an async LLM engine from the engine arguments.
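A brief construction sketch; the model name is a placeholder, and start_engine_loop=False is shown only to illustrate the parameter (with it disabled you must drive the engine yourself, e.g. along the lines of the engine_step() sketch above).
>>> # Sketch: create the engine without auto-starting the background loop.
>>> from vllm import AsyncEngineArgs, AsyncLLMEngine
>>>
>>> engine_args = AsyncEngineArgs(model="facebook/opt-125m")  # placeholder model
>>> engine = AsyncLLMEngine.from_engine_args(engine_args,
>>>                                          start_engine_loop=False)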
- async generate(inputs: str | TextPrompt | TokensPrompt | ExplicitEncoderDecoderPrompt, sampling_params: SamplingParams, request_id: str, lora_request: LoRARequest | None = None, trace_headers: Mapping[str, str] | None = None, prompt_adapter_request: PromptAdapterRequest | None = None) → AsyncGenerator[RequestOutput, None] [source]#
Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the request into the waiting queue of the LLMEngine and streams the outputs from the LLMEngine to the caller.
- Parameters:
inputs – The inputs to the LLM. See PromptInputs for more details about the format of each input.
sampling_params – The sampling parameters of the request.
request_id – The unique id of the request.
lora_request – LoRA request to use for generation, if any.
trace_headers – OpenTelemetry trace headers.
prompt_adapter_request – Prompt Adapter request to use for generation, if any.
- Yields:
The output RequestOutput objects from the LLMEngine for the request.
- Details:
If the engine is not running, start the background loop, which iteratively invokes engine_step() to process the waiting requests.
Add the request to the engine’s RequestTracker. On the next background loop, this request will be sent to the underlying engine. Also, a corresponding AsyncStream will be created.
Wait for the request outputs from AsyncStream and yield them.
Example
>>> # Please refer to entrypoints/api_server.py for
>>> # the complete example.
>>>
>>> # initialize the engine and the example input
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
>>> example_input = {
>>>     "prompt": "What is LLM?",
>>>     "stream": False,  # assume the non-streaming case
>>>     "temperature": 0.0,
>>>     "request_id": 0,
>>> }
>>>
>>> # start the generation
>>> results_generator = engine.generate(
>>>     example_input["prompt"],
>>>     SamplingParams(temperature=example_input["temperature"]),
>>>     example_input["request_id"])
>>>
>>> # get the results
>>> final_output = None
>>> async for request_output in results_generator:
>>>     # `request` is the incoming HTTP request object
>>>     # (e.g. a FastAPI Request, as in entrypoints/api_server.py).
>>>     if await request.is_disconnected():
>>>         # Abort the request if the client disconnects.
>>>         await engine.abort(example_input["request_id"])
>>>         # Return or raise an error
>>>         ...
>>>     final_output = request_output
>>>
>>> # Process and return the final output
>>> ...
- async get_decoding_config() → DecodingConfig [source]#
Get the decoding configuration of the vLLM engine.
- async get_parallel_config() → ParallelConfig [source]#
Get the parallel configuration of the vLLM engine.
- async get_scheduler_config() → SchedulerConfig [source]#
Get the scheduling configuration of the vLLM engine.
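These getters are coroutines and must be awaited. A quick, hedged sketch of reading them; the attribute names printed below (tensor_parallel_size, max_num_seqs) are assumed, illustrative fields rather than an exhaustive contract.
>>> # Sketch: await the config getters and inspect a few fields.
>>> async def inspect_engine(engine: AsyncLLMEngine) -> None:
>>>     decoding_config = await engine.get_decoding_config()
>>>     parallel_config = await engine.get_parallel_config()
>>>     scheduler_config = await engine.get_scheduler_config()
>>>     print(parallel_config.tensor_parallel_size)  # assumed field name
>>>     print(scheduler_config.max_num_seqs)         # assumed field name
>>>     print(decoding_config)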