LLMEngine#

class vllm.engine.llm_engine.LLMEngine(model_config: ModelConfig, cache_config: CacheConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, lora_config: LoRAConfig | None, placement_group: PlacementGroup | None, log_stats: bool)[source]#

An LLM engine that receives requests and generates text.

This is the main class for the vLLM engine. It receives requests from clients and generates text from the LLM. It includes a tokenizer, a language model (possibly distributed across multiple GPUs), and GPU memory space allocated for intermediate states (also known as the KV cache). This class uses iteration-level scheduling and efficient memory management to maximize serving throughput.

The LLM class wraps this class for offline batched inference and the AsyncLLMEngine class wraps this class for online serving.

NOTE: The config arguments are derived from the EngineArgs class. For the comprehensive list of arguments, see EngineArgs.

Parameters:
  • model_config – The configuration related to the LLM model.

  • cache_config – The configuration related to the KV cache memory management.

  • parallel_config – The configuration related to distributed execution.

  • scheduler_config – The configuration related to the request scheduler.

  • lora_config – The configuration related to serving LoRA adapters, or None if LoRA is not used.

  • placement_group – Ray placement group for distributed execution. Required for distributed execution.

  • log_stats – Whether to log statistics.
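
Example

A minimal sketch of constructing the engine indirectly via EngineArgs, as mentioned in the note above (the model name is only an example, and the top-level vllm imports are assumed to be available):

>>> from vllm import EngineArgs, LLMEngine
>>> # derive all config objects from a single EngineArgs instance
>>> engine_args = EngineArgs(model="facebook/opt-125m")
>>> engine = LLMEngine.from_engine_args(engine_args)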

_init_cache() None[source]#

Profiles the memory usage and initializes the KV cache.

The engine first profiles the existing memory usage. Then, it calculates the maximum possible number of GPU and CPU blocks that can be allocated with the remaining free memory. More details can be found in the profile_num_available_blocks() method of the Worker class.

Afterwards, because there may be multiple workers, the engine takes the minimum number of blocks across all workers so that the chosen allocation fits on every worker.

Finally, the engine will initialize the KV cache with the calculated number of blocks.
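
A minimal sketch of the cross-worker reduction described above, with hypothetical per-worker block counts:

>>> # hypothetical (num_gpu_blocks, num_cpu_blocks) reported by two workers
>>> per_worker_blocks = [(1200, 512), (1150, 512)]
>>> # take the minimum across workers so the allocation fits on every worker
>>> num_gpu_blocks = min(blocks[0] for blocks in per_worker_blocks)
>>> num_cpu_blocks = min(blocks[1] for blocks in per_worker_blocks)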

Tip

You may limit GPU memory usage by adjusting the gpu_memory_utilization parameter.
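
For example, assuming the gpu_memory_utilization argument of EngineArgs, the fraction of GPU memory made available to the engine can be capped as follows (0.8 is an illustrative value):

>>> # allow the engine to use at most 80% of each GPU's memory
>>> engine_args = EngineArgs(model="facebook/opt-125m", gpu_memory_utilization=0.8)
>>> engine = LLMEngine.from_engine_args(engine_args)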

abort_request(request_id: str | Iterable[str]) None[source]#

Aborts the request(s) with the given ID(s).

Parameters:
  • request_id – The ID(s) of the request(s) to abort.

Details:
  • Refer to the abort_seq_group() method of the Scheduler class.

Example

>>> # initialize engine and add a request with request_id
>>> request_id = str(0)
>>> # abort the request
>>> engine.abort_request(request_id)
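
Since request_id also accepts an iterable of IDs, several requests can be aborted in a single call (the IDs below are illustrative):

>>> # abort several requests at once
>>> engine.abort_request(["0", "1", "2"])
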
add_request(request_id: str, prompt: str | None, sampling_params: SamplingParams, prompt_token_ids: List[int] | None = None, arrival_time: float | None = None, lora_request: LoRARequest | None = None, prefix_pos: int | None = None) None[source]#

Add a request to the engine’s request pool.

The request is added to the request pool and will be processed by the scheduler as engine.step() is called. The exact scheduling policy is determined by the scheduler.

Parameters:
  • request_id – The unique ID of the request.

  • prompt – The prompt string. Can be None if prompt_token_ids is provided.

  • sampling_params – The sampling parameters for text generation.

  • prompt_token_ids – The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs.

  • arrival_time – The arrival time of the request. If None, we use the current monotonic time.

  • lora_request – The LoRA request (adapter) to use for this request, if any.

  • prefix_pos – If not None, we use the given position as the prefix position for each prompt. We will cache the prefix’s KV cache and reuse it for the next request with the same prefix. This is an experimental feature, and may be replaced with automatic prefix caching in the future.

Details:
  • Set arrival_time to the current time if it is None.

  • Set prompt_token_ids to the encoded prompt if it is None.

  • Create best_of number of Sequence objects.

  • Create a SequenceGroup object from the list of Sequence.

  • Add the SequenceGroup object to the scheduler.

Example

>>> # initialize engine
>>> engine = LLMEngine.from_engine_args(engine_args)
>>> # set request arguments
>>> example_prompt = "Who is the president of the United States?"
>>> sampling_params = SamplingParams(temperature=0.0)
>>> request_id = 0
>>>
>>> # add the request to the engine
>>> engine.add_request(
>>>    str(request_id),
>>>    example_prompt,
>>>    sampling_params)
>>> # continue the request processing
>>> ...
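
If the prompt has already been tokenized, the token IDs can be passed via prompt_token_ids and the prompt string left as None (the token IDs below are illustrative):

>>> # add a pre-tokenized request; the tokenizer is not invoked
>>> engine.add_request(
>>>    request_id="1",
>>>    prompt=None,
>>>    sampling_params=sampling_params,
>>>    prompt_token_ids=[2, 2264, 16, 10, 2335, 116])
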
step() List[RequestOutput][source]#

Performs one decoding iteration and returns newly generated results.

[Figure: Overview of the step function.]

Details:
  • Step 1: Schedules the sequences to be executed in the next iteration and the token blocks to be swapped in/out/copied.

    • Depending on the scheduling policy, sequences may be preempted/reordered.

    • A Sequence Group (SG) refers to a group of sequences that are generated from the same prompt.

  • Step 2: Calls the workers to execute the model.

  • Step 3: Processes the model output. This mainly includes:

    • Decodes the relevant outputs.

    • Updates the scheduled sequence groups with the model outputs based on their sampling parameters (e.g., whether use_beam_search is set).

    • Frees the finished sequence groups.

  • Finally, it creates and returns the newly generated results.

Example

>>> # Please see the examples/ folder for more detailed examples.
>>>
>>> # initialize engine and request arguments
>>> engine = LLMEngine.from_engine_args(engine_args)
>>> example_inputs = [(0, "What is LLM?",
>>>    SamplingParams(temperature=0.0))]
>>>
>>> # Start the engine with an event loop
>>> while True:
>>>     if example_inputs:
>>>         req_id, prompt, sampling_params = example_inputs.pop(0)
>>>         engine.add_request(str(req_id), prompt, sampling_params)
>>>
>>>     # continue the request processing
>>>     request_outputs = engine.step()
>>>     for request_output in request_outputs:
>>>         if request_output.finished:
>>>             # return or show the request output
>>>             print(request_output)
>>>
>>>     if not (engine.has_unfinished_requests() or example_inputs):
>>>         break
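
Each returned RequestOutput carries the completions generated so far; assuming the standard RequestOutput fields, the finished branch above could, for instance, print the generated text:

>>> # e.g., inside the `finished` branch of the loop above
>>> print(request_output.request_id, request_output.outputs[0].text)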