OpenAI Compatible Server#
vLLM provides an HTTP server that implements OpenAI’s Completions and Chat API.
You can start the server using Python, or using Docker:
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-hf --dtype float32 --api-key token-abc123
To call the server, you can use the official OpenAI Python client library, or any other HTTP client.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="token-abc123",
)

completion = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-hf",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
)

print(completion.choices[0].message)
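Since any HTTP client can be used, the same call can be sketched with only the Python standard library. This is a minimal illustration, not the project's own client code: the helper names `build_chat_request` and `send` are hypothetical, the route is assumed to be `/chat/completions` under the base URL shown above, and the API key is assumed to travel as a bearer token in the `Authorization` header, as the OpenAI client sends it.

```python
import json
import urllib.request

# Values taken from the example above; adjust them for your deployment.
BASE_URL = "http://localhost:8000/v1"
API_KEY = "token-abc123"


def build_chat_request(messages, model="meta-llama/Llama-2-7b-hf"):
    """Build (but do not send) a POST request for the chat completions route."""
    payload = {"model": model, "messages": messages}
    return urllib.request.Request(
        url=f"{BASE_URL}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            # Assumption: the key is presented as a bearer token.
            "Authorization": f"Bearer {API_KEY}",
        },
        method="POST",
    )


def send(request):
    """Send the request and decode the JSON reply (needs a running server)."""
    with urllib.request.urlopen(request) as response:
        return json.load(response)


request = build_chat_request([{"role": "user", "content": "Hello!"}])
print(request.full_url)
# With a server running:
# reply = send(request)
# print(reply["choices"][0]["message"]["content"])
```

Building the request separately from sending it makes it easy to inspect the exact URL, headers, and body before any traffic reaches the server.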
API Reference#
Please see the OpenAI API Reference for more information on the API. We support all parameters except:
- Chat: tools and tool_choice.
- Completions: suffix.
Extra Parameters#
vLLM supports a set of parameters that are not part of the OpenAI API. To use them, pass them as extra parameters in the OpenAI client, or merge them directly into the JSON payload if you are calling the HTTP API yourself.
completion = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-hf",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"},
    ],
    extra_body={
        "guided_choice": ["positive", "negative"],
    },
)
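For direct HTTP calls there is no `extra_body` wrapper: the extra keys sit at the top level of the JSON payload. A rough sketch of that merge follows; `merge_extra_params` is a hypothetical helper written for this illustration, and the collision check is an extra safety measure rather than something the server requires.

```python
import json


def merge_extra_params(payload, extra):
    """Return a copy of the payload with vLLM-specific keys at the top level."""
    overlap = set(payload) & set(extra)
    if overlap:
        # Refuse to silently overwrite a standard field such as "model".
        raise ValueError(f"extra parameters would overwrite: {sorted(overlap)}")
    return {**payload, **extra}


payload = {
    "model": "meta-llama/Llama-2-7b-hf",
    "messages": [
        {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"},
    ],
}
body = merge_extra_params(payload, {"guided_choice": ["positive", "negative"]})
print(json.dumps(body, indent=2))
```

The resulting `body` is what would be sent as the request JSON; the original `payload` dictionary is left unmodified.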
Extra Parameters for Chat API#
The following sampling parameters are supported.
best_of: Optional[int] = None
use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
min_p: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
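Because these names are not part of the OpenAI client's signature, a typo in `extra_body` is easy to miss. One possible guard is a small client-side check against the names listed above; the helper `sampling_extra_body` is hypothetical and the specific values are arbitrary examples, not recommendations.

```python
# Names copied from the list of Chat API sampling parameters above.
CHAT_SAMPLING_EXTRAS = {
    "best_of",
    "use_beam_search",
    "top_k",
    "min_p",
    "repetition_penalty",
    "length_penalty",
    "early_stopping",
    "ignore_eos",
    "min_tokens",
    "stop_token_ids",
    "skip_special_tokens",
    "spaces_between_special_tokens",
}


def sampling_extra_body(**params):
    """Build an extra_body dict, rejecting names that are not in the list."""
    unknown = set(params) - CHAT_SAMPLING_EXTRAS
    if unknown:
        raise ValueError(f"unknown sampling parameters: {sorted(unknown)}")
    return dict(params)


extra_body = sampling_extra_body(top_k=40, min_p=0.05, repetition_penalty=1.1)
print(extra_body)
# With a client configured as in the first example:
# client.chat.completions.create(model=..., messages=..., extra_body=extra_body)
```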
The following extra parameters are supported:
echo: Optional[bool] = Field(
    default=False,
    description=(
        "If true, the new message will be prepended with the last message "
        "if they belong to the same role."),
)
add_generation_prompt: Optional[bool] = Field(
    default=True,
    description=
    ("If true, the generation prompt will be added to the chat template. "
     "This is a parameter used by chat template in tokenizer config of the "
     "model."),
)
include_stop_str_in_output: Optional[bool] = Field(
    default=False,
    description=(
        "Whether to include the stop string in the output. "
        "This is only applied when the stop or stop_token_ids is set."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
    default=None,
    description=("If specified, the output will follow the JSON schema."),
)
guided_regex: Optional[str] = Field(
    default=None,
    description=(
        "If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[List[str]] = Field(
    default=None,
    description=(
        "If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
    default=None,
    description=(
        "If specified, the output will follow the context free grammar."),
)
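As an illustration of `guided_json`, the sketch below passes a JSON schema as a plain dictionary, one of the accepted types listed above. The schema, the helper names `guided_json_kwargs` and `parse_reply`, and the prompt are all invented for this example; the local key check is only a sanity test on the reply, not a full schema validator.

```python
import json

# Hypothetical schema used only for this illustration.
SENTIMENT_SCHEMA = {
    "type": "object",
    "properties": {
        "label": {"type": "string", "enum": ["positive", "negative"]},
        "confidence": {"type": "number"},
    },
    "required": ["label", "confidence"],
}


def guided_json_kwargs(model, prompt, schema):
    """Keyword arguments for client.chat.completions.create with guided_json."""
    return {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "extra_body": {"guided_json": schema},
    }


def parse_reply(text, schema=SENTIMENT_SCHEMA):
    """Decode the reply and confirm that the required keys are present."""
    data = json.loads(text)
    missing = [key for key in schema["required"] if key not in data]
    if missing:
        raise ValueError(f"reply is missing keys: {missing}")
    return data


kwargs = guided_json_kwargs(
    "meta-llama/Llama-2-7b-hf",
    "Classify this sentiment: vLLM is wonderful!",
    SENTIMENT_SCHEMA,
)
# With a client configured as in the first example:
# completion = client.chat.completions.create(**kwargs)
# result = parse_reply(completion.choices[0].message.content)
```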
Extra Parameters for Completions API#
The following sampling parameters are supported.
use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
min_p: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
The following extra parameters are supported:
include_stop_str_in_output: Optional[bool] = Field(
    default=False,
    description=(
        "Whether to include the stop string in the output. "
        "This is only applied when the stop or stop_token_ids is set."),
)
response_format: Optional[ResponseFormat] = Field(
    default=None,
    description=
    ("Similar to chat completion, this parameter specifies the format of "
     "output. Only {'type': 'json_object'} or {'type': 'text' } is "
     "supported."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
    default=None,
    description=("If specified, the output will follow the JSON schema."),
)
guided_regex: Optional[str] = Field(
    default=None,
    description=(
        "If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[List[str]] = Field(
    default=None,
    description=(
        "If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
    default=None,
    description=(
        "If specified, the output will follow the context free grammar."),
)
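The same mechanism applies to the Completions API. The sketch below prepares a `guided_regex` request; the helper `regex_completion_kwargs`, the prompt, and the pattern are hypothetical. Python's `re` module is used only as a local sanity check, and the server's regex dialect may not match it exactly.

```python
import re


def regex_completion_kwargs(model, prompt, pattern, max_tokens=16):
    """Keyword arguments for client.completions.create with guided_regex."""
    re.compile(pattern)  # raises re.error early if the pattern is malformed
    return {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "extra_body": {"guided_regex": pattern},
    }


# Example pattern: four dot-separated groups of one to three digits.
IPV4_LIKE = r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}"

kwargs = regex_completion_kwargs(
    "meta-llama/Llama-2-7b-hf",
    "The address of the DNS server is ",
    IPV4_LIKE,
)
# With a client configured as in the first example:
# completion = client.completions.create(**kwargs)
# text = completion.choices[0].text
# print(bool(re.fullmatch(IPV4_LIKE, text)))
```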
Chat Template#
For a language model to support the chat protocol, vLLM requires the model to include a chat template in its tokenizer configuration. The chat template is a Jinja2 template that specifies how roles, messages, and other chat-specific tokens are encoded in the input.
An example chat template for meta-llama/Llama-2-7b-chat-hf can be found here.
Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those models, you can manually specify the chat template with the --chat-template parameter, giving either the file path to the chat template or the template in string form. Without a chat template, the server cannot process chat requests, and all chat requests will return an error.
python -m vllm.entrypoints.openai.api_server \
--model ... \
--chat-template ./path-to-chat-template.jinja
The vLLM community provides a set of chat templates for popular models. You can find them in the examples directory.
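To make the file-based option concrete, here is a minimal sketch that writes a toy template to disk and assembles the matching launch command. The template is deliberately simplistic and hypothetical (real models need their own special tokens and layout), the model name is a placeholder, and the helpers `write_template` and `server_command` exist only in this example.

```python
import tempfile
from pathlib import Path

# Toy Jinja2 template: one "role: content" line per message, followed by an
# optional generation prompt. Not suitable for any real model as written.
SIMPLE_TEMPLATE = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}assistant: {% endif %}"
)


def write_template(directory, name="simple-chat-template.jinja"):
    """Write the template into the given directory and return its path."""
    path = Path(directory) / name
    path.write_text(SIMPLE_TEMPLATE, encoding="utf-8")
    return path


def server_command(model, template_path):
    """Argument list that starts the server with an explicit chat template."""
    return [
        "python", "-m", "vllm.entrypoints.openai.api_server",
        "--model", model,
        "--chat-template", str(template_path),
    ]


template_path = write_template(tempfile.mkdtemp())
# "my-org/my-chat-model" is a placeholder, not a real model name.
command = server_command("my-org/my-chat-model", template_path)
print(" ".join(command))
```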
Command line arguments for the server#
vLLM OpenAI-Compatible RESTful API server.
usage: vllm-openai-server [-h] [--host HOST] [--port PORT]
[--uvicorn-log-level {debug,info,warning,error,critical,trace}]
[--allow-credentials]
[--allowed-origins ALLOWED_ORIGINS]
[--allowed-methods ALLOWED_METHODS]
[--allowed-headers ALLOWED_HEADERS]
[--api-key API_KEY]
[--served-model-name SERVED_MODEL_NAME]
[--lora-modules LORA_MODULES [LORA_MODULES ...]]
[--chat-template CHAT_TEMPLATE]
[--response-role RESPONSE_ROLE]
[--ssl-keyfile SSL_KEYFILE]
[--ssl-certfile SSL_CERTFILE]
[--ssl-ca-certs SSL_CA_CERTS]
[--ssl-cert-reqs SSL_CERT_REQS]
[--root-path ROOT_PATH] [--middleware MIDDLEWARE]
[--model MODEL] [--tokenizer TOKENIZER]
[--revision REVISION]
[--code-revision CODE_REVISION]
[--tokenizer-revision TOKENIZER_REVISION]
[--tokenizer-mode {auto,slow}] [--trust-remote-code]
[--download-dir DOWNLOAD_DIR]
[--load-format {auto,pt,safetensors,npcache,dummy}]
[--dtype {auto,half,float16,bfloat16,float,float32}]
[--kv-cache-dtype {auto,fp8_e5m2}]
[--max-model-len MAX_MODEL_LEN] [--worker-use-ray]
[--pipeline-parallel-size PIPELINE_PARALLEL_SIZE]
[--tensor-parallel-size TENSOR_PARALLEL_SIZE]
[--max-parallel-loading-workers MAX_PARALLEL_LOADING_WORKERS]
[--ray-workers-use-nsight]
[--block-size {8,16,32,128}]
[--enable-prefix-caching] [--use-v2-block-manager]
[--num-lookahead-slots NUM_LOOKAHEAD_SLOTS]
[--seed SEED] [--swap-space SWAP_SPACE]
[--gpu-memory-utilization GPU_MEMORY_UTILIZATION]
[--forced-num-gpu-blocks FORCED_NUM_GPU_BLOCKS]
[--max-num-batched-tokens MAX_NUM_BATCHED_TOKENS]
[--max-num-seqs MAX_NUM_SEQS]
[--max-logprobs MAX_LOGPROBS] [--disable-log-stats]
[--quantization {awq,gptq,squeezellm,None}]
[--enforce-eager]
[--max-context-len-to-capture MAX_CONTEXT_LEN_TO_CAPTURE]
[--disable-custom-all-reduce]
[--tokenizer-pool-size TOKENIZER_POOL_SIZE]
[--tokenizer-pool-type TOKENIZER_POOL_TYPE]
[--tokenizer-pool-extra-config TOKENIZER_POOL_EXTRA_CONFIG]
[--enable-lora] [--max-loras MAX_LORAS]
[--max-lora-rank MAX_LORA_RANK]
[--lora-extra-vocab-size LORA_EXTRA_VOCAB_SIZE]
[--lora-dtype {auto,float16,bfloat16,float32}]
[--max-cpu-loras MAX_CPU_LORAS]
[--device {auto,cuda,neuron,cpu}]
[--image-input-type {pixel_values,image_features}]
[--image-token-id IMAGE_TOKEN_ID]
[--image-input-shape IMAGE_INPUT_SHAPE]
[--image-feature-size IMAGE_FEATURE_SIZE]
[--scheduler-delay-factor SCHEDULER_DELAY_FACTOR]
[--enable-chunked-prefill ENABLE_CHUNKED_PREFILL]
[--engine-use-ray] [--disable-log-requests]
[--max-log-len MAX_LOG_LEN]
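With this many options, it can help to assemble the command line programmatically. The following is a sketch under stated assumptions: `build_server_argv` is a hypothetical helper, boolean `True` is mapped only to switches that the usage text shows without a value (such as --enable-lora), and the LoRA name and path are placeholders in the name=path format that --lora-modules expects.

```python
import json
import shlex


def build_server_argv(options, lora_modules=None):
    """Turn an options mapping into an argument list for the API server."""
    argv = ["python", "-m", "vllm.entrypoints.openai.api_server"]
    for flag, value in options.items():
        if value is True:
            argv.append(f"--{flag}")  # switch that takes no value
        elif value is False or value is None:
            continue  # leave the server default in place
        elif isinstance(value, dict):
            argv += [f"--{flag}", json.dumps(value)]  # JSON-string options
        else:
            argv += [f"--{flag}", str(value)]
    if lora_modules:
        argv.append("--lora-modules")
        argv += [f"{name}={path}" for name, path in lora_modules.items()]
    return argv


argv = build_server_argv(
    {
        "model": "meta-llama/Llama-2-7b-hf",
        "port": 8000,
        "enable-lora": True,
        "disable-log-stats": False,
    },
    lora_modules={"sql-lora": "/path/to/sql-lora"},  # placeholder name and path
)
print(shlex.join(argv))
# To launch: subprocess.run(argv, check=True)
```

Passing a list to `subprocess.run` avoids shell quoting problems with values such as JSON strings.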
Named Arguments#
- --host
host name
- --port
port number
Default: 8000
- --uvicorn-log-level
Possible choices: debug, info, warning, error, critical, trace
log level for uvicorn
Default: “info”
- --allow-credentials
allow credentials
Default: False
- --allowed-origins
allowed origins
Default: [‘*’]
- --allowed-methods
allowed methods
Default: [‘*’]
- --allowed-headers
allowed headers
Default: [‘*’]
- --api-key
If provided, the server will require this key to be presented in the header.
- --served-model-name
The model name used in the API. If not specified, the model name will be the same as the huggingface name.
- --lora-modules
LoRA module configurations in the format name=path. Multiple modules can be specified.
- --chat-template
The file path to the chat template, or the template in single-line form for the specified model
- --response-role
The role name to return if request.add_generation_prompt=true.
Default: “assistant”
- --ssl-keyfile
The file path to the SSL key file
- --ssl-certfile
The file path to the SSL cert file
- --ssl-ca-certs
The CA certificates file
- --ssl-cert-reqs
Whether a client certificate is required (see the stdlib ssl module)
Default: 0
- --root-path
FastAPI root_path when app is behind a path based routing proxy
- --middleware
Additional ASGI middleware to apply to the app. We accept multiple --middleware arguments. The value should be an import path. If a function is provided, vLLM will add it to the server using @app.middleware('http'). If a class is provided, vLLM will add it to the server using app.add_middleware().
Default: []
- --model
name or path of the huggingface model to use
Default: “facebook/opt-125m”
- --tokenizer
name or path of the huggingface tokenizer to use
- --revision
the specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
- --code-revision
the specific revision to use for the model code on Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
- --tokenizer-revision
the specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
- --tokenizer-mode
Possible choices: auto, slow
tokenizer mode. “auto” will use the fast tokenizer if available, and “slow” will always use the slow tokenizer.
Default: “auto”
- --trust-remote-code
trust remote code from huggingface
Default: False
- --download-dir
directory to download and load the weights, default to the default cache dir of huggingface
- --load-format
Possible choices: auto, pt, safetensors, npcache, dummy
The format of the model weights to load. “auto” will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available. “pt” will load the weights in the pytorch bin format. “safetensors” will load the weights in the safetensors format. “npcache” will load the weights in pytorch format and store a numpy cache to speed up the loading. “dummy” will initialize the weights with random values, which is mainly for profiling.
Default: “auto”
- --dtype
Possible choices: auto, half, float16, bfloat16, float, float32
data type for model weights and activations. The “auto” option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
Default: “auto”
- --kv-cache-dtype
Possible choices: auto, fp8_e5m2
Data type for kv cache storage. If “auto”, will use model data type. Note FP8 is not supported when cuda version is lower than 11.8.
Default: “auto”
- --max-model-len
model context length. If unspecified, will be automatically derived from the model.
- --worker-use-ray
use Ray for distributed serving, will be automatically set when using more than 1 GPU
Default: False
- --pipeline-parallel-size, -pp
number of pipeline stages
Default: 1
- --tensor-parallel-size, -tp
number of tensor parallel replicas
Default: 1
- --max-parallel-loading-workers
load model sequentially in multiple batches, to avoid RAM OOM when using tensor parallel and large models
- --ray-workers-use-nsight
If specified, use nsight to profile ray workers
Default: False
- --block-size
Possible choices: 8, 16, 32, 128
token block size
Default: 16
- --enable-prefix-caching
Enables automatic prefix caching
Default: False
- --use-v2-block-manager
Use BlockSpaceManagerV2
Default: False
- --num-lookahead-slots
Experimental scheduling config necessary for speculative decoding. This will be replaced by speculative config in the future; it is present to enable correctness tests until then.
Default: 0
- --seed
random seed
Default: 0
- --swap-space
CPU swap space size (GiB) per GPU
Default: 4
- --gpu-memory-utilization
the fraction of GPU memory to be used for the model executor, which can range from 0 to 1. If unspecified, will use the default value of 0.9.
Default: 0.9
- --forced-num-gpu-blocks
If specified, ignore GPU profiling result and use this number of GPU blocks. Used for testing preemption.
- --max-num-batched-tokens
maximum number of batched tokens per iteration
- --max-num-seqs
maximum number of sequences per iteration
Default: 256
- --max-logprobs
max number of log probs to return when logprobs is specified in SamplingParams
Default: 5
- --disable-log-stats
disable logging statistics
Default: False
- --quantization, -q
Possible choices: awq, gptq, squeezellm, None
Method used to quantize the weights. If None, we first check the quantization_config attribute in the model config file. If that is None, we assume the model weights are not quantized and use dtype to determine the data type of the weights.
- --enforce-eager
Always use eager-mode PyTorch. If False, will use eager mode and CUDA graph in hybrid for maximal performance and flexibility.
Default: False
- --max-context-len-to-capture
maximum context length covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode.
Default: 8192
- --disable-custom-all-reduce
See ParallelConfig
Default: False
- --tokenizer-pool-size
Size of tokenizer pool to use for asynchronous tokenization. If 0, will use synchronous tokenization.
Default: 0
- --tokenizer-pool-type
Type of tokenizer pool to use for asynchronous tokenization. Ignored if tokenizer_pool_size is 0.
Default: “ray”
- --tokenizer-pool-extra-config
Extra config for tokenizer pool. This should be a JSON string that will be parsed into a dictionary. Ignored if tokenizer_pool_size is 0.
- --enable-lora
If True, enable handling of LoRA adapters.
Default: False
- --max-loras
Max number of LoRAs in a single batch.
Default: 1
- --max-lora-rank
Max LoRA rank.
Default: 16
- --lora-extra-vocab-size
Maximum size of extra vocabulary that can be present in a LoRA adapter (added to the base model vocabulary).
Default: 256
- --lora-dtype
Possible choices: auto, float16, bfloat16, float32
Data type for LoRA. If auto, will default to base model dtype.
Default: “auto”
- --max-cpu-loras
Maximum number of LoRAs to store in CPU memory. Must be >= max_num_seqs. Defaults to max_num_seqs.
- --device
Possible choices: auto, cuda, neuron, cpu
Device type for vLLM execution.
Default: “auto”
- --image-input-type
Possible choices: pixel_values, image_features
The image input type passed into vLLM. Should be one of “pixel_values” or “image_features”.
- --image-token-id
Input id for image token.
- --image-input-shape
The biggest image input shape (worst for memory footprint) given an input type. Only used for vLLM’s profile_run.
- --image-feature-size
The image feature size along the context dimension.
- --scheduler-delay-factor
Apply a delay (of delay factor multiplied by previous prompt latency) before scheduling next prompt.
Default: 0.0
- --enable-chunked-prefill
If True, the prefill requests can be chunked based on the max_num_batched_tokens
Default: False
- --engine-use-ray
use Ray to start the LLM engine in a separate process from the server process.
Default: False
- --disable-log-requests
disable logging requests
Default: False
- --max-log-len
max number of prompt characters or prompt ID numbers being printed in log. Default: unlimited.
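As a sketch of the function form accepted by --middleware, the example below follows the usual shape of a FastAPI/Starlette HTTP middleware: an async callable that takes the request and a `call_next` callback. The function name `add_timing_header`, the header name, and the module name in the final comment are hypothetical, and the dotted import path format is an assumption. The stand-in objects exist only so that the function can be exercised without a running server.

```python
import asyncio
import time


async def add_timing_header(request, call_next):
    """HTTP middleware: time the downstream handler and report it in a header."""
    start = time.perf_counter()
    response = await call_next(request)
    response.headers["X-Process-Time"] = f"{time.perf_counter() - start:.4f}"
    return response


class _FakeResponse:
    """Stand-in for a framework response; it only carries headers."""

    def __init__(self):
        self.headers = {}


async def _fake_call_next(request):
    """Stand-in for the rest of the application."""
    return _FakeResponse()


demo = asyncio.run(add_timing_header(object(), _fake_call_next))
print(sorted(demo.headers))
# Assumed usage, if the function is saved in a module named my_middleware.py:
#   --middleware my_middleware.add_timing_header
```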