LLMEngine#

class vllm.LLMEngine(model_config: ModelConfig, cache_config: CacheConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, load_config: LoadConfig, lora_config: LoRAConfig | None, speculative_config: SpeculativeConfig | None, decoding_config: DecodingConfig | None, observability_config: ObservabilityConfig | None, prompt_adapter_config: PromptAdapterConfig | None, executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Dict[str, StatLoggerBase] | None = None, input_registry: InputRegistry = INPUT_REGISTRY)[source]#

An LLM engine that receives requests and generates text.

This is the main class for the vLLM engine. It receives requests from clients and generates text from the LLM. It includes a tokenizer, a language model (possibly distributed across multiple GPUs), and GPU memory space allocated for intermediate states (also known as the KV cache). This class utilizes iteration-level scheduling and efficient memory management to maximize serving throughput.

The LLM class wraps this class for offline batched inference and the AsyncLLMEngine class wraps this class for online serving.

The config arguments are derived from EngineArgs. (See Engine Arguments.) A short construction sketch follows the parameter list below.

Parameters:
  • model_config – The configuration related to the LLM model.

  • cache_config – The configuration related to the KV cache memory management.

  • parallel_config – The configuration related to distributed execution.

  • scheduler_config – The configuration related to the request scheduler.

  • device_config – The configuration related to the device.

  • lora_config (Optional) – The configuration related to serving multi-LoRA.

  • speculative_config (Optional) – The configuration related to speculative decoding.

  • executor_class – The model executor class for managing distributed execution.

  • prompt_adapter_config (Optional) – The configuration related to serving prompt adapters.

  • log_stats – Whether to log statistics.

  • usage_context – Specified entry point, used for usage info collection.
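
In practice these config objects are rarely constructed by hand: they are usually derived from EngineArgs via from_engine_args() (documented below), and offline batched inference typically goes through the LLM wrapper instead. A minimal sketch of the wrapper path (the model name "facebook/opt-125m" is only an illustrative placeholder):

>>> from vllm import LLM, SamplingParams
>>> # the LLM class builds and drives an LLMEngine internally
>>> llm = LLM(model="facebook/opt-125m")
>>> outputs = llm.generate(["Hello, my name is"],
>>>     SamplingParams(temperature=0.0))
>>> print(outputs[0].outputs[0].text)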

DO_VALIDATE_OUTPUT: ClassVar[bool] = False#

A flag to toggle whether to validate the type of request output.

abort_request(request_id: str | Iterable[str]) None[source]#

Aborts one or more requests with the given ID(s).

Parameters:

request_id – The ID(s) of the request(s) to abort.

Details:
  • Refer to abort_seq_group() in the Scheduler class.

Example

>>> # initialize engine and add a request with request_id
>>> request_id = str(0)
>>> # abort the request
>>> engine.abort_request(request_id)
add_request(request_id: str, inputs: str | TextPrompt | TokensPrompt | ExplicitEncoderDecoderPrompt, params: SamplingParams | PoolingParams, arrival_time: float | None = None, lora_request: LoRARequest | None = None, trace_headers: Mapping[str, str] | None = None, prompt_adapter_request: PromptAdapterRequest | None = None) None[source]#

Add a request to the engine’s request pool.

The request is added to the request pool and will be processed by the scheduler as engine.step() is called. The exact scheduling policy is determined by the scheduler.

Parameters:
  • request_id – The unique ID of the request.

  • inputs – The inputs to the LLM. See PromptInputs for more details about the format of each input.

  • params – Parameters for sampling or pooling. SamplingParams for text generation. PoolingParams for pooling.

  • arrival_time – The arrival time of the request. If None, we use the current monotonic time.

  • trace_headers – OpenTelemetry trace headers.

Details:
  • Set arrival_time to the current time if it is None.

  • Set prompt_token_ids to the encoded prompt if it is None.

  • Create best_of Sequence objects.

  • Create a SequenceGroup object from the list of Sequence.

  • Add the SequenceGroup object to the scheduler.

Example

>>> # initialize engine
>>> engine = LLMEngine.from_engine_args(engine_args)
>>> # set request arguments
>>> example_prompt = "Who is the president of the United States?"
>>> sampling_params = SamplingParams(temperature=0.0)
>>> request_id = 0
>>>
>>> # add the request to the engine
>>> engine.add_request(
>>>    str(request_id),
>>>    example_prompt,
>>>    sampling_params)
>>> # continue the request processing
>>> ...
do_log_stats(scheduler_outputs: SchedulerOutputs | None = None, model_output: List[SamplerOutput] | None = None) None[source]#

Forces a stats log even when there are no active requests.
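
Example

A minimal sketch, assuming the engine was created with log_stats enabled:

>>> # emit a statistics log entry even while the engine is idle
>>> engine.do_log_stats()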

classmethod from_engine_args(engine_args: EngineArgs, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Dict[str, StatLoggerBase] | None = None) LLMEngine[source]#

Creates an LLM engine from the engine arguments.
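
Example

A minimal sketch of the usual construction path (the model name is only an illustrative placeholder):

>>> from vllm import EngineArgs, LLMEngine
>>> engine_args = EngineArgs(model="facebook/opt-125m")
>>> engine = LLMEngine.from_engine_args(engine_args)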

get_decoding_config() DecodingConfig[source]#

Gets the decoding configuration.

get_lora_config() LoRAConfig[source]#

Gets the LoRA configuration.

get_model_config() ModelConfig[source]#

Gets the model configuration.

get_num_unfinished_requests() int[source]#

Gets the number of unfinished requests.

get_parallel_config() ParallelConfig[source]#

Gets the parallel configuration.

get_scheduler_config() SchedulerConfig[source]#

Gets the scheduler configuration.

has_unfinished_requests() bool[source]#

Returns True if there are unfinished requests.
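
Example

A minimal sketch of draining the engine, assuming requests were already added via add_request():

>>> # drain all pending requests
>>> while engine.has_unfinished_requests():
>>>     request_outputs = engine.step()
>>> print(engine.get_num_unfinished_requests())  # 0 once fully drained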

has_unfinished_requests_for_virtual_engine(virtual_engine: int) bool[source]#

Returns True if there are unfinished requests for the virtual engine.

step() List[RequestOutput | EmbeddingRequestOutput][source]#

Performs one decoding iteration and returns newly generated results.

Figure: Overview of the step function.

Details:
  • Step 1: Schedules the sequences to be executed in the next iteration and the token blocks to be swapped in, swapped out, or copied.

    • Depending on the scheduling policy, sequences may be preempted/reordered.

    • A Sequence Group (SG) refers to a group of sequences that are generated from the same prompt.

  • Step 2: Calls the distributed executor to execute the model.

  • Step 3: Processes the model output. This mainly includes:

    • Decodes the relevant outputs.

    • Updates the scheduled sequence groups with model outputs based on their sampling parameters (e.g., whether use_beam_search is enabled).

    • Frees the finished sequence groups.

  • Finally, it creates and returns the newly generated results.

Example

>>> # Please see the examples/ folder for more detailed examples.
>>>
>>> # initialize engine and request arguments
>>> engine = LLMEngine.from_engine_args(engine_args)
>>> example_inputs = [(0, "What is LLM?",
>>>    SamplingParams(temperature=0.0))]
>>>
>>> # Start the engine with an event loop
>>> while True:
>>>     if example_inputs:
>>>         req_id, prompt, sampling_params = example_inputs.pop(0)
>>>         engine.add_request(str(req_id), prompt, sampling_params)
>>>
>>>     # continue the request processing
>>>     request_outputs = engine.step()
>>>     for request_output in request_outputs:
>>>         if request_output.finished:
>>>             # return or show the request output
>>>             print(request_output.outputs[0].text)
>>>
>>>     if not (engine.has_unfinished_requests() or example_inputs):
>>>         break