OpenAI Compatible Server#
vLLM provides an HTTP server that implements OpenAI’s Completions and Chat API.
You can start the server using Python, or using Docker:
python -m vllm.entrypoints.openai.api_server --model mistralai/Mistral-7B-Instruct-v0.2 --dtype auto --api-key token-abc123
To call the server, you can use the official OpenAI Python client library, or any other HTTP client.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="token-abc123",
)

completion = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    messages=[
        {"role": "user", "content": "Hello!"}
    ]
)

print(completion.choices[0].message)
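For the "any other HTTP client" route, the same request can be sketched with only the Python standard library. This is a minimal illustration, not the project's own client code: it assumes the server was started with the command above (port 8000, API key token-abc123) and that the key is presented as a Bearer token in the Authorization header, which is how the OpenAI client sends it. The helper names build_chat_request and send are made up for this example.

```python
import json
import urllib.request

BASE_URL = "http://localhost:8000/v1"  # assumption: server started as shown above
API_KEY = "token-abc123"               # must match the value given to --api-key


def build_chat_request(content: str) -> urllib.request.Request:
    """Build (but do not send) a POST request for the chat completions endpoint."""
    payload = {
        "model": "mistralai/Mistral-7B-Instruct-v0.2",
        "messages": [{"role": "user", "content": content}],
    }
    return urllib.request.Request(
        url=f"{BASE_URL}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {API_KEY}",
        },
        method="POST",
    )


def send(request: urllib.request.Request) -> dict:
    """Send the request and decode the JSON response body."""
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode("utf-8"))


request = build_chat_request("Hello!")
print(request.get_method(), request.full_url)
```

Once the server is running, send(request)["choices"][0]["message"] should hold the same message that the OpenAI client example prints.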
API Reference#
Please see the OpenAI API Reference for more information on the API. We support all parameters except:
Chat: tools and tool_choice.
Completions: suffix.
Extra Parameters#
vLLM supports a set of parameters that are not part of the OpenAI API. To use them, pass them as extra parameters through the OpenAI client's extra_body argument, or merge them directly into the JSON payload if you are calling the HTTP API yourself.
completion = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    messages=[
        {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
    ],
    extra_body={
        "guided_choice": ["positive", "negative"]
    }
)
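When calling the HTTP API without the OpenAI client, the extra parameters sit at the top level of the JSON body, next to model and messages. The sketch below shows one way to do that merge. The helper name merge_extra_parameters and the overlap check are illustrative choices, and the JSON schema is a made-up example, not something the server requires.

```python
import json


def merge_extra_parameters(payload: dict, extra: dict) -> dict:
    """Return a new request body with the vLLM-specific keys at the top level."""
    overlap = payload.keys() & extra.keys()
    if overlap:
        raise ValueError(f"extra parameters would overwrite: {sorted(overlap)}")
    return {**payload, **extra}


base = {
    "model": "mistralai/Mistral-7B-Instruct-v0.2",
    "messages": [
        {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
    ],
}

# Same constraint as the extra_body example above.
choice_body = merge_extra_parameters(base, {"guided_choice": ["positive", "negative"]})

# A JSON-schema constraint instead (hypothetical schema, for illustration only).
schema = {
    "type": "object",
    "properties": {"sentiment": {"type": "string", "enum": ["positive", "negative"]}},
    "required": ["sentiment"],
}
json_body = merge_extra_parameters(base, {"guided_json": schema})

print(json.dumps(choice_body, indent=2))
```

Either body can then be sent as the POST payload to the chat completions endpoint.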
Extra Parameters for Chat API#
The following sampling parameters are supported (see the SamplingParams documentation for details).
best_of: Optional[int] = None
use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
min_p: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
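As a rough sketch, these sampling parameters can be collected into a dictionary and checked against the names listed above before being passed as extra_body. The helper sampling_extra_body is hypothetical and only guards against misspelled keys; it does not validate value ranges.

```python
# Names copied from the sampling-parameter list above.
CHAT_SAMPLING_KEYS = {
    "best_of", "use_beam_search", "top_k", "min_p", "repetition_penalty",
    "length_penalty", "early_stopping", "ignore_eos", "min_tokens",
    "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens",
}


def sampling_extra_body(**params) -> dict:
    """Build an extra_body dict, rejecting names that are not in the list."""
    unknown = set(params) - CHAT_SAMPLING_KEYS
    if unknown:
        raise ValueError(f"not a listed sampling parameter: {sorted(unknown)}")
    return dict(params)


extra_body = sampling_extra_body(top_k=40, min_p=0.05, repetition_penalty=1.1)
print(extra_body)
```

The result can be handed to the client as client.chat.completions.create(..., extra_body=extra_body).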
The following extra parameters are supported:
echo: Optional[bool] = Field(
    default=False,
    description=(
        "If true, the new message will be prepended with the last message "
        "if they belong to the same role."),
)
add_generation_prompt: Optional[bool] = Field(
    default=True,
    description=
    ("If true, the generation prompt will be added to the chat template. "
     "This is a parameter used by chat template in tokenizer config of the "
     "model."),
)
include_stop_str_in_output: Optional[bool] = Field(
    default=False,
    description=(
        "Whether to include the stop string in the output. "
        "This is only applied when the stop or stop_token_ids is set."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
    default=None,
    description=("If specified, the output will follow the JSON schema."),
)
guided_regex: Optional[str] = Field(
    default=None,
    description=(
        "If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[List[str]] = Field(
    default=None,
    description=(
        "If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
    default=None,
    description=(
        "If specified, the output will follow the context free grammar."),
)
guided_decoding_backend: Optional[str] = Field(
    default=None,
    description=(
        "If specified, will override the default guided decoding backend "
        "of the server for this specific request. If set, must be either "
        "'outlines' / 'lm-format-enforcer'"))
Extra Parameters for Completions API#
The following sampling parameters are supported (see the SamplingParams documentation for details).
use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
min_p: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
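For the Completions API, the same parameters go into the JSON body sent to the /v1/completions endpoint. The sketch below builds such a body and mirrors the ge=1 constraint on truncate_prompt_tokens from the list above. The helper name build_completion_body and the chosen values are illustrative assumptions; prompt and max_tokens are standard OpenAI Completions fields.

```python
import json


def build_completion_body(prompt, *, max_tokens=32,
                          truncate_prompt_tokens=None, **sampling):
    """Build a Completions request body with vLLM sampling extras merged in."""
    if truncate_prompt_tokens is not None and truncate_prompt_tokens < 1:
        # Mirrors Field(ge=1) in the parameter list above.
        raise ValueError("truncate_prompt_tokens must be >= 1")
    body = {
        "model": "mistralai/Mistral-7B-Instruct-v0.2",
        "prompt": prompt,
        "max_tokens": max_tokens,
    }
    body.update(sampling)
    if truncate_prompt_tokens is not None:
        body["truncate_prompt_tokens"] = truncate_prompt_tokens
    return body


body = build_completion_body(
    "vLLM is", top_k=20, min_tokens=4, truncate_prompt_tokens=512
)
print(json.dumps(body, indent=2))
```

With the OpenAI client, the non-standard keys (top_k, min_tokens, truncate_prompt_tokens) would instead go into extra_body of client.completions.create.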
The following extra parameters are supported:
include_stop_str_in_output: Optional[bool] = Field(
    default=False,
    description=(
        "Whether to include the stop string in the output. "
        "This is only applied when the stop or stop_token_ids is set."),
)
response_format: Optional[ResponseFormat] = Field(
    default=None,
    description=
    ("Similar to chat completion, this parameter specifies the format of "
     "output. Only {'type': 'json_object'} or {'type': 'text' } is "
     "supported."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
    default=None,
    description=("If specified, the output will follow the JSON schema."),
)
guided_regex: Optional[str] = Field(
    default=None,
    description=(
        "If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[List[str]] = Field(
    default=None,
    description=(
        "If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
    default=None,
    description=(
        "If specified, the output will follow the context free grammar."),
)
guided_decoding_backend: Optional[str] = Field(
    default=None,
    description=(
        "If specified, will override the default guided decoding backend "
        "of the server for this specific request. If set, must be one of "
        "'outlines' / 'lm-format-enforcer'"))
Chat Template#
In order for the language model to support the chat protocol, vLLM requires the model to include a chat template in its tokenizer configuration. The chat template is a Jinja2 template that specifies how roles, messages, and other chat-specific tokens are encoded in the input.
An example chat template for mistralai/Mistral-7B-Instruct-v0.2 can be found in that model's tokenizer configuration.
Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those models, you can manually specify the chat template with the --chat-template parameter, passing either the file path to the chat template or the template in string form. Without a chat template, the server cannot process chat requests, and all of them will return an error.
python -m vllm.entrypoints.openai.api_server \
--model ... \
--chat-template ./path-to-chat-template.jinja
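To make the template file concrete, the sketch below writes a small Jinja2 chat template to disk and builds the matching server command. The `<|role|>` prompt layout and the model name my-org/my-model are invented for illustration; a real model needs the layout it was fine-tuned with. The variables messages and add_generation_prompt are the ones chat templates conventionally receive.

```python
import tempfile
from pathlib import Path

# Hypothetical layout: "<|role|>\ncontent\n" per message, then an assistant
# header when a generation prompt is requested. Not any real model's format.
TEMPLATE = (
    r"{% for message in messages %}"
    r"{{ '<|' + message['role'] + '|>\n' + message['content'] + '\n' }}"
    r"{% endfor %}"
    r"{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}"
)


def write_template(directory: Path) -> Path:
    """Write the template to a .jinja file and return its path."""
    path = directory / "chat-template.jinja"
    path.write_text(TEMPLATE, encoding="utf-8")
    return path


def server_command(model: str, template_path: Path) -> list:
    """Build the argument list for starting the server with this template."""
    return [
        "python", "-m", "vllm.entrypoints.openai.api_server",
        "--model", model,
        "--chat-template", str(template_path),
    ]


template_path = write_template(Path(tempfile.mkdtemp()))
command = server_command("my-org/my-model", template_path)  # placeholder model
print(" ".join(command))
```

The printed command can be run in a shell, or the list can be passed to subprocess.run.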
The vLLM community provides a set of chat templates for popular models. You can find them in the examples directory.
Command line arguments for the server#
vLLM OpenAI-Compatible RESTful API server.
usage: vllm-openai-server [-h] [--host HOST] [--port PORT]
[--uvicorn-log-level {debug,info,warning,error,critical,trace}]
[--allow-credentials]
[--allowed-origins ALLOWED_ORIGINS]
[--allowed-methods ALLOWED_METHODS]
[--allowed-headers ALLOWED_HEADERS]
[--api-key API_KEY]
[--served-model-name SERVED_MODEL_NAME [SERVED_MODEL_NAME ...]]
[--lora-modules LORA_MODULES [LORA_MODULES ...]]
[--chat-template CHAT_TEMPLATE]
[--response-role RESPONSE_ROLE]
[--ssl-keyfile SSL_KEYFILE]
[--ssl-certfile SSL_CERTFILE]
[--ssl-ca-certs SSL_CA_CERTS]
[--ssl-cert-reqs SSL_CERT_REQS]
[--root-path ROOT_PATH] [--middleware MIDDLEWARE]
[--model MODEL] [--tokenizer TOKENIZER]
[--skip-tokenizer-init] [--revision REVISION]
[--code-revision CODE_REVISION]
[--tokenizer-revision TOKENIZER_REVISION]
[--tokenizer-mode {auto,slow}] [--trust-remote-code]
[--download-dir DOWNLOAD_DIR]
[--load-format {auto,pt,safetensors,npcache,dummy,tensorizer}]
[--dtype {auto,half,float16,bfloat16,float,float32}]
[--kv-cache-dtype {auto,fp8}]
[--quantization-param-path QUANTIZATION_PARAM_PATH]
[--max-model-len MAX_MODEL_LEN]
[--guided-decoding-backend {outlines,lm-format-enforcer}]
[--worker-use-ray]
[--pipeline-parallel-size PIPELINE_PARALLEL_SIZE]
[--tensor-parallel-size TENSOR_PARALLEL_SIZE]
[--max-parallel-loading-workers MAX_PARALLEL_LOADING_WORKERS]
[--ray-workers-use-nsight] [--block-size {8,16,32}]
[--enable-prefix-caching] [--use-v2-block-manager]
[--num-lookahead-slots NUM_LOOKAHEAD_SLOTS]
[--seed SEED] [--swap-space SWAP_SPACE]
[--gpu-memory-utilization GPU_MEMORY_UTILIZATION]
[--num-gpu-blocks-override NUM_GPU_BLOCKS_OVERRIDE]
[--max-num-batched-tokens MAX_NUM_BATCHED_TOKENS]
[--max-num-seqs MAX_NUM_SEQS]
[--max-logprobs MAX_LOGPROBS] [--disable-log-stats]
[--quantization {aqlm,awq,fp8,gptq,squeezellm,marlin,None}]
[--enforce-eager]
[--max-context-len-to-capture MAX_CONTEXT_LEN_TO_CAPTURE]
[--disable-custom-all-reduce]
[--tokenizer-pool-size TOKENIZER_POOL_SIZE]
[--tokenizer-pool-type TOKENIZER_POOL_TYPE]
[--tokenizer-pool-extra-config TOKENIZER_POOL_EXTRA_CONFIG]
[--enable-lora] [--max-loras MAX_LORAS]
[--max-lora-rank MAX_LORA_RANK]
[--lora-extra-vocab-size LORA_EXTRA_VOCAB_SIZE]
[--lora-dtype {auto,float16,bfloat16,float32}]
[--max-cpu-loras MAX_CPU_LORAS]
[--device {auto,cuda,neuron,cpu}]
[--image-input-type {pixel_values,image_features}]
[--image-token-id IMAGE_TOKEN_ID]
[--image-input-shape IMAGE_INPUT_SHAPE]
[--image-feature-size IMAGE_FEATURE_SIZE]
[--scheduler-delay-factor SCHEDULER_DELAY_FACTOR]
[--enable-chunked-prefill]
[--speculative-model SPECULATIVE_MODEL]
[--num-speculative-tokens NUM_SPECULATIVE_TOKENS]
[--speculative-max-model-len SPECULATIVE_MAX_MODEL_LEN]
[--model-loader-extra-config MODEL_LOADER_EXTRA_CONFIG]
[--engine-use-ray] [--disable-log-requests]
[--max-log-len MAX_LOG_LEN]
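When the server is launched from a script, it can be convenient to assemble these flags from a dictionary. The following is a minimal sketch under that assumption: the helper build_server_args is made up for this example, it uses only flags that appear in the usage text above, and it starts the server through the Python module shown at the top of this page.

```python
import shlex


def build_server_args(options: dict) -> list:
    """Turn {"tensor_parallel_size": 2, ...} into a command-line argument list."""
    args = ["python", "-m", "vllm.entrypoints.openai.api_server"]
    for name, value in options.items():
        flag = "--" + name.replace("_", "-")
        if value is True:
            args.append(flag)               # store-true style switch
        elif value is False or value is None:
            continue                        # leave the flag out entirely
        elif isinstance(value, (list, tuple)):
            args.append(flag)               # multi-value flag, e.g. --lora-modules
            args.extend(str(item) for item in value)
        else:
            args.extend([flag, str(value)])
    return args


args = build_server_args({
    "model": "mistralai/Mistral-7B-Instruct-v0.2",
    "dtype": "auto",
    "port": 8000,
    "tensor_parallel_size": 2,
    "gpu_memory_utilization": 0.9,
    "enable_prefix_caching": True,
    "disable_log_requests": False,
})
print(shlex.join(args))
```

The resulting list can be passed directly to subprocess.Popen, which avoids shell quoting issues.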
Named Arguments#
- --host
host name
- --port
port number
Default: 8000
- --uvicorn-log-level
Possible choices: debug, info, warning, error, critical, trace
log level for uvicorn
Default: “info”
- --allow-credentials
allow credentials
Default: False
- --allowed-origins
allowed origins
Default: [‘*’]
- --allowed-methods
allowed methods
Default: [‘*’]
- --allowed-headers
allowed headers
Default: [‘*’]
- --api-key
If provided, the server will require this key to be presented in the header.
- --served-model-name
The model name(s) used in the API. If multiple names are provided, the server will respond to any of the provided names. The model name in the model field of a response will be the first name in this list. If not specified, the model name will be the same as the --model argument.
- --lora-modules
LoRA module configurations in the format name=path. Multiple modules can be specified.
- --chat-template
The file path to the chat template, or the template in single-line form for the specified model
- --response-role
The role name to return if request.add_generation_prompt=true.
Default: “assistant”
- --ssl-keyfile
The file path to the SSL key file
- --ssl-certfile
The file path to the SSL cert file
- --ssl-ca-certs
The CA certificates file
- --ssl-cert-reqs
Whether client certificate is required (see stdlib ssl module’s)
Default: 0
- --root-path
FastAPI root_path when app is behind a path based routing proxy
- --middleware
Additional ASGI middleware to apply to the app. We accept multiple --middleware arguments. The value should be an import path. If a function is provided, vLLM will add it to the server using @app.middleware('http'). If a class is provided, vLLM will add it to the server using app.add_middleware().
Default: []
- --model
Name or path of the huggingface model to use.
Default: “facebook/opt-125m”
- --tokenizer
Name or path of the huggingface tokenizer to use.
- --skip-tokenizer-init
Skip initialization of tokenizer and detokenizer
Default: False
- --revision
The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
- --code-revision
The specific revision to use for the model code on Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
- --tokenizer-revision
The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
- --tokenizer-mode
Possible choices: auto, slow
The tokenizer mode.
“auto” will use the fast tokenizer if available.
“slow” will always use the slow tokenizer.
Default: “auto”
- --trust-remote-code
Trust remote code from huggingface.
Default: False
- --download-dir
Directory to download and load the weights. Defaults to the default cache directory of huggingface.
- --load-format
Possible choices: auto, pt, safetensors, npcache, dummy, tensorizer
The format of the model weights to load.
“auto” will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available.
“pt” will load the weights in the pytorch bin format.
“safetensors” will load the weights in the safetensors format.
“npcache” will load the weights in pytorch format and store a numpy cache to speed up the loading.
“dummy” will initialize the weights with random values, which is mainly for profiling.
“tensorizer” will load the weights using tensorizer from CoreWeave which assumes tensorizer_uri is set to the location of the serialized weights.
Default: “auto”
- --dtype
Possible choices: auto, half, float16, bfloat16, float, float32
Data type for model weights and activations.
“auto” will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
“half” for FP16. Recommended for AWQ quantization.
“float16” is the same as “half”.
“bfloat16” for a balance between precision and range.
“float” is shorthand for FP32 precision.
“float32” for FP32 precision.
Default: “auto”
- --kv-cache-dtype
Possible choices: auto, fp8
Data type for kv cache storage. If “auto”, will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
Default: “auto”
- --quantization-param-path
Path to the JSON file containing the KV cache scaling factors. This should generally be supplied when the KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
- --max-model-len
Model context length. If unspecified, will be automatically derived from the model config.
- --guided-decoding-backend
Possible choices: outlines, lm-format-enforcer
Which engine will be used for guided decoding (JSON schema / regex etc) by default. Currently supports outlines-dev/outlines and noamgat/lm-format-enforcer. Can be overridden per request via the guided_decoding_backend parameter.
Default: “outlines”
- --worker-use-ray
Use Ray for distributed serving, will be automatically set when using more than 1 GPU.
Default: False
- --pipeline-parallel-size, -pp
Number of pipeline stages.
Default: 1
- --tensor-parallel-size, -tp
Number of tensor parallel replicas.
Default: 1
- --max-parallel-loading-workers
Load model sequentially in multiple batches, to avoid RAM OOM when using tensor parallel and large models.
- --ray-workers-use-nsight
If specified, use nsight to profile Ray workers.
Default: False
- --block-size
Possible choices: 8, 16, 32
Token block size for contiguous chunks of tokens.
Default: 16
- --enable-prefix-caching
Enables automatic prefix caching.
Default: False
- --use-v2-block-manager
Use BlockSpaceManagerV2.
Default: False
- --num-lookahead-slots
Experimental scheduling config necessary for speculative decoding. This will be replaced by speculative config in the future; it is present to enable correctness tests until then.
Default: 0
- --seed
Random seed for operations.
Default: 0
- --swap-space
CPU swap space size (GiB) per GPU.
Default: 4
- --gpu-memory-utilization
The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory utilization. If unspecified, will use the default value of 0.9.
Default: 0.9
- --num-gpu-blocks-override
If specified, ignore GPU profiling result and use this number of GPU blocks. Used for testing preemption.
- --max-num-batched-tokens
Maximum number of batched tokens per iteration.
- --max-num-seqs
Maximum number of sequences per iteration.
Default: 256
- --max-logprobs
Max number of log probs to return when logprobs is specified in SamplingParams.
Default: 5
- --disable-log-stats
Disable logging statistics.
Default: False
- --quantization, -q
Possible choices: aqlm, awq, fp8, gptq, squeezellm, marlin, None
Method used to quantize the weights. If None, we first check the quantization_config attribute in the model config file. If that is None, we assume the model weights are not quantized and use dtype to determine the data type of the weights.
- --enforce-eager
Always use eager-mode PyTorch. If False, will use eager mode and CUDA graph in hybrid for maximal performance and flexibility.
Default: False
- --max-context-len-to-capture
Maximum context length covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode.
Default: 8192
- --disable-custom-all-reduce
See ParallelConfig.
Default: False
- --tokenizer-pool-size
Size of tokenizer pool to use for asynchronous tokenization. If 0, will use synchronous tokenization.
Default: 0
- --tokenizer-pool-type
Type of tokenizer pool to use for asynchronous tokenization. Ignored if tokenizer_pool_size is 0.
Default: “ray”
- --tokenizer-pool-extra-config
Extra config for tokenizer pool. This should be a JSON string that will be parsed into a dictionary. Ignored if tokenizer_pool_size is 0.
- --enable-lora
If True, enable handling of LoRA adapters.
Default: False
- --max-loras
Max number of LoRAs in a single batch.
Default: 1
- --max-lora-rank
Max LoRA rank.
Default: 16
- --lora-extra-vocab-size
Maximum size of extra vocabulary that can be present in a LoRA adapter (added to the base model vocabulary).
Default: 256
- --lora-dtype
Possible choices: auto, float16, bfloat16, float32
Data type for LoRA. If auto, will default to base model dtype.
Default: “auto”
- --max-cpu-loras
Maximum number of LoRAs to store in CPU memory. Must be >= max_num_seqs. Defaults to max_num_seqs.
- --device
Possible choices: auto, cuda, neuron, cpu
Device type for vLLM execution.
Default: “auto”
- --image-input-type
Possible choices: pixel_values, image_features
The image input type passed into vLLM. Should be one of “pixel_values” or “image_features”.
- --image-token-id
Input id for image token.
- --image-input-shape
The biggest image input shape (worst for memory footprint) given an input type. Only used for vLLM’s profile_run.
- --image-feature-size
The image feature size along the context dimension.
- --scheduler-delay-factor
Apply a delay (of delay factor multiplied by previous prompt latency) before scheduling next prompt.
Default: 0.0
- --enable-chunked-prefill
If set, the prefill requests can be chunked based on the max_num_batched_tokens.
Default: False
- --speculative-model
The name of the draft model to be used in speculative decoding.
- --num-speculative-tokens
The number of speculative tokens to sample from the draft model in speculative decoding.
- --speculative-max-model-len
The maximum sequence length supported by the draft model. Sequences over this length will skip speculation.
- --model-loader-extra-config
Extra config for model loader. This will be passed to the model loader corresponding to the chosen load_format. This should be a JSON string that will be parsed into a dictionary.
- --engine-use-ray
Use Ray to start the LLM engine in a process separate from the server process.
Default: False
- --disable-log-requests
Disable logging requests.
Default: False
- --max-log-len
Max number of prompt characters or prompt ID numbers being printed in log.
Default: Unlimited
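As an illustration of the --middleware option described above, the sketch below defines a function-style HTTP middleware with the (request, call_next) signature that @app.middleware('http') expects in FastAPI. The header name X-Process-Time and the stand-in objects used to exercise the function without a running server are assumptions made for this example.

```python
import asyncio
import time


async def add_process_time_header(request, call_next):
    """Time the downstream handler and report the duration in a response header."""
    start = time.perf_counter()
    response = await call_next(request)
    response.headers["X-Process-Time"] = f"{time.perf_counter() - start:.4f}"
    return response


class FakeResponse:
    """Stand-in for a real response object; only carries headers."""

    def __init__(self):
        self.headers = {}


async def fake_call_next(request):
    """Stand-in for the rest of the application."""
    return FakeResponse()


# Exercise the middleware locally, without FastAPI or a running server.
response = asyncio.run(add_process_time_header(object(), fake_call_next))
print(response.headers)
```

If this function lived in a hypothetical importable module named my_middleware, it would presumably be enabled with --middleware my_middleware.add_process_time_header, assuming the import path takes the dotted module.attribute form.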