Engine Arguments#
Engine arguments control the behavior of the vLLM engine.
For offline inference, they are part of the arguments to the `LLM` class. For online serving, they are part of the arguments to `vllm serve`.
For a reference to all arguments available from `vllm serve`, see the serve args documentation.
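For example, here is a minimal sketch of both entry points, assuming the documented default model and illustrative argument values; CLI flags map to `LLM` keyword arguments with dashes replaced by underscores:

```python
from vllm import LLM, SamplingParams

# Offline inference: engine arguments become constructor keyword arguments
# (the CLI flag --max-model-len maps to max_model_len, and so on).
llm = LLM(
    model="facebook/opt-125m",   # same as --model
    dtype="auto",                # same as --dtype
    max_model_len=2048,          # same as --max-model-len (illustrative value)
    gpu_memory_utilization=0.9,  # same as --gpu-memory-utilization
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)

# Online serving: the equivalent engine arguments as CLI flags, e.g.
#   vllm serve facebook/opt-125m --dtype auto --max-model-len 2048 \
#       --gpu-memory-utilization 0.9
```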
Below, you can find an explanation of every engine argument:
usage: vllm serve [-h] [--model MODEL]
[--task {auto,classify,draft,embed,embedding,generate,reward,score,transcription}]
[--tokenizer TOKENIZER]
[--tokenizer-mode {auto,custom,mistral,slow}]
[--trust-remote-code | --no-trust-remote-code]
[--dtype {auto,bfloat16,float,float16,float32,half}]
[--seed SEED] [--hf-config-path HF_CONFIG_PATH]
[--allowed-local-media-path ALLOWED_LOCAL_MEDIA_PATH]
[--revision REVISION] [--code-revision CODE_REVISION]
[--rope-scaling ROPE_SCALING] [--rope-theta ROPE_THETA]
[--tokenizer-revision TOKENIZER_REVISION]
[--max-model-len MAX_MODEL_LEN]
[--quantization {aqlm,auto-round,awq,awq_marlin,bitblas,bitsandbytes,compressed-tensors,deepspeedfp,experts_int8,fbgemm_fp8,fp8,gguf,gptq,gptq_bitblas,gptq_marlin,gptq_marlin_24,hqq,ipex,marlin,modelopt,moe_wna16,neuron_quant,nvfp4,ptpc_fp8,qqq,quark,torchao,tpu_int8,None}]
[--enforce-eager | --no-enforce-eager]
[--max-seq-len-to-capture MAX_SEQ_LEN_TO_CAPTURE]
[--max-logprobs MAX_LOGPROBS]
[--disable-sliding-window | --no-disable-sliding-window]
[--disable-cascade-attn | --no-disable-cascade-attn]
[--skip-tokenizer-init | --no-skip-tokenizer-init]
[--enable-prompt-embeds | --no-enable-prompt-embeds]
[--served-model-name SERVED_MODEL_NAME [SERVED_MODEL_NAME ...]]
[--disable-async-output-proc]
[--config-format {auto,hf,mistral}] [--hf-token [HF_TOKEN]]
[--hf-overrides HF_OVERRIDES]
[--override-neuron-config OVERRIDE_NEURON_CONFIG]
[--override-pooler-config OVERRIDE_POOLER_CONFIG]
[--logits-processor-pattern LOGITS_PROCESSOR_PATTERN]
[--generation-config GENERATION_CONFIG]
[--override-generation-config OVERRIDE_GENERATION_CONFIG]
[--enable-sleep-mode | --no-enable-sleep-mode]
[--model-impl {auto,vllm,transformers}]
[--load-format {auto,pt,safetensors,npcache,dummy,tensorizer,sharded_state,gguf,bitsandbytes,mistral,runai_streamer,runai_streamer_sharded,fastsafetensors}]
[--download-dir DOWNLOAD_DIR]
[--model-loader-extra-config MODEL_LOADER_EXTRA_CONFIG]
[--ignore-patterns IGNORE_PATTERNS [IGNORE_PATTERNS ...]]
[--use-tqdm-on-load | --no-use-tqdm-on-load]
[--qlora-adapter-name-or-path QLORA_ADAPTER_NAME_OR_PATH]
[--pt-load-map-location PT_LOAD_MAP_LOCATION]
[--guided-decoding-backend {auto,guidance,lm-format-enforcer,outlines,xgrammar}]
[--guided-decoding-disable-fallback | --no-guided-decoding-disable-fallback]
[--guided-decoding-disable-any-whitespace | --no-guided-decoding-disable-any-whitespace]
[--guided-decoding-disable-additional-properties | --no-guided-decoding-disable-additional-properties]
[--enable-reasoning | --no-enable-reasoning]
[--reasoning-parser {deepseek_r1,granite,qwen3}]
[--distributed-executor-backend {external_launcher,mp,ray,uni,None}]
[--pipeline-parallel-size PIPELINE_PARALLEL_SIZE]
[--tensor-parallel-size TENSOR_PARALLEL_SIZE]
[--data-parallel-size DATA_PARALLEL_SIZE]
[--data-parallel-size-local DATA_PARALLEL_SIZE_LOCAL]
[--data-parallel-address DATA_PARALLEL_ADDRESS]
[--data-parallel-rpc-port DATA_PARALLEL_RPC_PORT]
[--enable-expert-parallel | --no-enable-expert-parallel]
[--max-parallel-loading-workers MAX_PARALLEL_LOADING_WORKERS]
[--ray-workers-use-nsight | --no-ray-workers-use-nsight]
[--disable-custom-all-reduce | --no-disable-custom-all-reduce]
[--worker-cls WORKER_CLS]
[--worker-extension-cls WORKER_EXTENSION_CLS]
[--block-size {1,8,16,32,64,128}]
[--gpu-memory-utilization GPU_MEMORY_UTILIZATION]
[--swap-space SWAP_SPACE]
[--kv-cache-dtype {auto,fp8,fp8_e4m3,fp8_e5m2}]
[--num-gpu-blocks-override NUM_GPU_BLOCKS_OVERRIDE]
[--enable-prefix-caching | --no-enable-prefix-caching]
[--prefix-caching-hash-algo {builtin,sha256}]
[--cpu-offload-gb CPU_OFFLOAD_GB]
[--calculate-kv-scales | --no-calculate-kv-scales]
[--tokenizer-pool-size TOKENIZER_POOL_SIZE]
[--tokenizer-pool-type TOKENIZER_POOL_TYPE]
[--tokenizer-pool-extra-config TOKENIZER_POOL_EXTRA_CONFIG]
[--limit-mm-per-prompt LIMIT_MM_PER_PROMPT]
[--mm-processor-kwargs MM_PROCESSOR_KWARGS]
[--disable-mm-preprocessor-cache | --no-disable-mm-preprocessor-cache]
[--enable-lora | --no-enable-lora]
[--enable-lora-bias | --no-enable-lora-bias]
[--max-loras MAX_LORAS] [--max-lora-rank MAX_LORA_RANK]
[--lora-extra-vocab-size LORA_EXTRA_VOCAB_SIZE]
[--lora-dtype {auto,bfloat16,float16}]
[--long-lora-scaling-factors LONG_LORA_SCALING_FACTORS [LONG_LORA_SCALING_FACTORS ...]]
[--max-cpu-loras MAX_CPU_LORAS]
[--fully-sharded-loras | --no-fully-sharded-loras]
[--enable-prompt-adapter | --no-enable-prompt-adapter]
[--max-prompt-adapters MAX_PROMPT_ADAPTERS]
[--max-prompt-adapter-token MAX_PROMPT_ADAPTER_TOKEN]
[--device {auto,cpu,cuda,hpu,neuron,tpu,xpu}]
[--speculative-config SPECULATIVE_CONFIG]
[--show-hidden-metrics-for-version SHOW_HIDDEN_METRICS_FOR_VERSION]
[--otlp-traces-endpoint OTLP_TRACES_ENDPOINT]
[--collect-detailed-traces {all,model,worker,None} [{all,model,worker,None} ...]]
[--max-num-batched-tokens MAX_NUM_BATCHED_TOKENS]
[--max-num-seqs MAX_NUM_SEQS]
[--max-num-partial-prefills MAX_NUM_PARTIAL_PREFILLS]
[--max-long-partial-prefills MAX_LONG_PARTIAL_PREFILLS]
[--cuda-graph-sizes CUDA_GRAPH_SIZES [CUDA_GRAPH_SIZES ...]]
[--long-prefill-token-threshold LONG_PREFILL_TOKEN_THRESHOLD]
[--num-lookahead-slots NUM_LOOKAHEAD_SLOTS]
[--scheduler-delay-factor SCHEDULER_DELAY_FACTOR]
[--preemption-mode {recompute,swap,None}]
[--num-scheduler-steps NUM_SCHEDULER_STEPS]
[--multi-step-stream-outputs | --no-multi-step-stream-outputs]
[--scheduling-policy {fcfs,priority}]
[--enable-chunked-prefill | --no-enable-chunked-prefill]
[--disable-chunked-mm-input | --no-disable-chunked-mm-input]
[--scheduler-cls SCHEDULER_CLS]
[--kv-transfer-config KV_TRANSFER_CONFIG]
[--kv-events-config KV_EVENTS_CONFIG]
[--compilation-config COMPILATION_CONFIG]
[--additional-config ADDITIONAL_CONFIG]
[--use-v2-block-manager] [--disable-log-stats]
Named Arguments#
- --use-v2-block-manager
[DEPRECATED] block manager v1 has been removed and SelfAttnBlockSpaceManager (i.e. block manager v2) is now the default. Setting this flag to True or False has no effect on vLLM behavior.
- --disable-log-stats
Disable logging statistics.
ModelConfig#
Configuration for the model.
- --model
Name or path of the Hugging Face model to use. It is also used as the content for the `model_name` tag in metrics output when `served_model_name` is not specified.
Default:
'facebook/opt-125m'
- --task
Possible choices: auto, classify, draft, embed, embedding, generate, reward, score, transcription
The task to use the model for. Each vLLM instance only supports one task, even if the same model can be used for multiple tasks. When the model only supports one task, “auto” can be used to select it; otherwise, you must specify explicitly which task to use.
Default:
'auto'
- --tokenizer
Name or path of the Hugging Face tokenizer to use. If unspecified, model name or path will be used.
- --tokenizer-mode
Possible choices: auto, custom, mistral, slow
Tokenizer mode:
“auto” will use the fast tokenizer if available.
“slow” will always use the slow tokenizer.
“mistral” will always use the tokenizer from `mistral_common`.
“custom” will use `--tokenizer` to select the preregistered tokenizer.
Default:
'auto'
- --trust-remote-code, --no-trust-remote-code
Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.
Default:
False
- --dtype
Possible choices: auto, bfloat16, float, float16, float32, half
Data type for model weights and activations:
“auto” will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
“half” for FP16. Recommended for AWQ quantization.
“float16” is the same as “half”.
“bfloat16” for a balance between precision and range.
“float” is shorthand for FP32 precision.
“float32” for FP32 precision.
Default:
'auto'
- --seed
Random seed for reproducibility. Initialized to None in V0, but initialized to 0 in V1.
- --hf-config-path
Name or path of the Hugging Face config to use. If unspecified, model name or path will be used.
- --allowed-local-media-path
Allows API requests to read local images or videos from directories specified by the server's filesystem. This is a security risk and should only be enabled in trusted environments.
Default:
''
- --revision
The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
- --code-revision
The specific revision to use for the model code on the Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
- --rope-scaling
RoPE scaling configuration. For example, `{"rope_type":"dynamic","factor":2.0}`.
Should either be a valid JSON string or JSON keys passed individually. For example, the following sets of arguments are equivalent:
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`
Default:
{}
- --rope-theta
RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE theta improves the performance of the scaled model.
- --tokenizer-revision
The specific revision to use for the tokenizer on the Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
- --max-model-len
Model context length (prompt and output). If unspecified, will be automatically derived from the model config.
When passing via `--max-model-len`, supports k/m/g/K/M/G in human-readable format. Examples:
1k -> 1000
1K -> 1024
25.6k -> 25,600
- --quantization, -q
Possible choices: aqlm, auto-round, awq, awq_marlin, bitblas, bitsandbytes, compressed-tensors, deepspeedfp, experts_int8, fbgemm_fp8, fp8, gguf, gptq, gptq_bitblas, gptq_marlin, gptq_marlin_24, hqq, ipex, marlin, modelopt, moe_wna16, neuron_quant, nvfp4, ptpc_fp8, qqq, quark, torchao, tpu_int8, None
Method used to quantize the weights. If `None`, we first check the `quantization_config` attribute in the model config file. If that is `None`, we assume the model weights are not quantized and use `dtype` to determine the data type of the weights.
- --enforce-eager, --no-enforce-eager
Whether to always use eager-mode PyTorch. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid for maximal performance and flexibility.
Default:
False
- --max-seq-len-to-capture
Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. Additionally for encoder-decoder models, if the sequence length of the encoder input is larger than this, we fall back to the eager mode.
Default:
8192
- --max-logprobs
Maximum number of log probabilities to return when `logprobs` is specified in `SamplingParams`. The default value comes from the default for the OpenAI Chat Completions API.
Default:
20
- --disable-sliding-window, --no-disable-sliding-window
Whether to disable sliding window. If True, we will disable the sliding window functionality of the model, capping to sliding window size. If the model does not support sliding window, this argument is ignored.
Default:
False
- --disable-cascade-attn, --no-disable-cascade-attn
Disable cascade attention for V1. While cascade attention does not change the mathematical correctness, disabling it could be useful for preventing potential numerical issues. Note that even if this is set to False, cascade attention will be only used when the heuristic tells that it’s beneficial.
Default:
False
- --skip-tokenizer-init, --no-skip-tokenizer-init
Skip initialization of tokenizer and detokenizer. Expects valid `prompt_token_ids` and `None` for the prompt from the input. The generated output will contain token ids.
Default:
False
- --enable-prompt-embeds, --no-enable-prompt-embeds
If `True`, enables passing text embeddings as inputs via the `prompt_embeds` key. Note that enabling this will double the time required for graph compilation.
Default:
False
- --served-model-name
The model name(s) used in the API. If multiple names are provided, the server will respond to any of the provided names. The model name in the model field of a response will be the first name in this list. If not specified, the model name will be the same as the `--model` argument. Note that this name (or names) will also be used in the `model_name` tag of Prometheus metrics; if multiple names are provided, the metrics tag will use the first one.
- --disable-async-output-proc
Disable async output processing. This may result in lower performance.
- --config-format
Possible choices: auto, hf, mistral
The format of the model config to load:
“auto” will try to load the config in hf format if available, otherwise it will try to load it in mistral format.
“hf” will load the config in hf format.
“mistral” will load the config in mistral format.
Default:
'auto'
- --hf-token
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
- --hf-overrides
If a dictionary, contains arguments to be forwarded to the Hugging Face config. If a callable, it is called to update the HuggingFace config.
Default:
{}
- --override-neuron-config
Initialize a non-default neuron config or override the default neuron config that is specific to Neuron devices. This argument is used to configure the neuron config that cannot be gathered from the vLLM arguments, e.g. `{"cast_logits_dtype": "bfloat16"}`.
Should either be a valid JSON string or JSON keys passed individually. For example, the following sets of arguments are equivalent:
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`
Default:
{}
- --override-pooler-config
Initialize a non-default pooling config or override the default pooling config for the pooling model, e.g. `{"pooling_type": "mean", "normalize": false}`.
- --logits-processor-pattern
Optional regex pattern specifying valid logits processor qualified names that can be passed with the `logits_processors` extra completion argument. Defaults to `None`, which allows no processors.
- --generation-config
The folder path to the generation config. Defaults to `"auto"`, in which case the generation config will be loaded from the model path. If set to `"vllm"`, no generation config is loaded and vLLM defaults will be used. If set to a folder path, the generation config will be loaded from the specified folder path. If `max_new_tokens` is specified in the generation config, it sets a server-wide limit on the number of output tokens for all requests.
Default:
'auto'
- --override-generation-config
Overrides or sets generation config, e.g. `{"temperature": 0.5}`. If used with `--generation-config auto`, the override parameters will be merged with the default config from the model. If used with `--generation-config vllm`, only the override parameters are used.
Should either be a valid JSON string or JSON keys passed individually. For example, the following sets of arguments are equivalent:
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`
Default:
{}
- --enable-sleep-mode, --no-enable-sleep-mode
Enable sleep mode for the engine (only cuda platform is supported).
Default:
False
- --model-impl
Possible choices: auto, vllm, transformers
Which implementation of the model to use:
“auto” will try to use the vLLM implementation, if it exists, and fall back to the Transformers implementation if no vLLM implementation is available.
“vllm” will use the vLLM model implementation.
“transformers” will use the Transformers model implementation.
Default:
'auto'
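As an illustration of the ModelConfig options above, the following sketch passes them as `LLM` keyword arguments (the model name and values are assumptions, not recommendations):

```python
from vllm import LLM

# A hedged sketch of ModelConfig options set through the LLM constructor.
# The model name is illustrative; pick any HF model you have access to.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # --model
    task="generate",                           # --task
    dtype="bfloat16",                          # --dtype
    seed=0,                                    # --seed
    max_model_len=16384,                       # --max-model-len (the CLI also accepts "16K")
    trust_remote_code=False,                   # --trust-remote-code
)

# Online equivalent using the flags documented above:
#   vllm serve meta-llama/Llama-3.1-8B-Instruct --task generate \
#       --dtype bfloat16 --seed 0 --max-model-len 16K
```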
LoadConfig#
Configuration for loading the model weights.
- --load-format
Possible choices: auto, pt, safetensors, npcache, dummy, tensorizer, sharded_state, gguf, bitsandbytes, mistral, runai_streamer, runai_streamer_sharded, fastsafetensors
The format of the model weights to load:
“auto” will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available.
“pt” will load the weights in the pytorch bin format.
“safetensors” will load the weights in the safetensors format.
“npcache” will load the weights in pytorch format and store a numpy cache to speed up the loading.
“dummy” will initialize the weights with random values, which is mainly for profiling.
“tensorizer” will use CoreWeave’s tensorizer library for fast weight loading. See the Tensorize vLLM Model script in the Examples section for more information.
“runai_streamer” will load the Safetensors weights using Run:ai Model Streamer.
“bitsandbytes” will load the weights using bitsandbytes quantization.
“sharded_state” will load weights from pre-sharded checkpoint files, supporting efficient loading of tensor-parallel models.
“gguf” will load weights from GGUF format files (details specified in https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).
“mistral” will load weights from consolidated safetensors files used by Mistral models.
Default:
'auto'
- --download-dir
Directory to download and load the weights; defaults to the default cache directory of Hugging Face.
- --model-loader-extra-config
Extra config for model loader. This will be passed to the model loader corresponding to the chosen load_format.
Should either be a valid JSON string or JSON keys passed individually. For example, the following sets of arguments are equivalent:
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`
Default:
{}
- --ignore-patterns
The list of patterns to ignore when loading the model. Defaults to “original/**/*” to avoid repeated loading of llama’s checkpoints.
- --use-tqdm-on-load, --no-use-tqdm-on-load
Whether to enable tqdm for showing progress bar when loading model weights.
Default:
True
- --qlora-adapter-name-or-path
The `--qlora-adapter-name-or-path` option has no effect; do not set it. It will be removed in v0.10.0.
- --pt-load-map-location
The map location for loading the PyTorch checkpoint, to support checkpoints that can only be loaded on certain devices such as "cuda"; this is equivalent to `{"": "cuda"}`. Another supported format is a mapping between devices, e.g. from GPU 1 to GPU 0: `{"cuda:1": "cuda:0"}`. Note that when passed from the command line, the strings in the dictionary need to be double quoted for JSON parsing. For more details, see the documentation for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html
Default:
cpu
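A hedged sketch of the LoadConfig options above as `LLM` keyword arguments; the download directory and model name are illustrative:

```python
from vllm import LLM

# Illustrative LoadConfig settings; adjust paths and model to your environment.
llm = LLM(
    model="facebook/opt-125m",        # --model
    load_format="safetensors",        # --load-format
    download_dir="/tmp/vllm-models",  # --download-dir (hypothetical path)
    use_tqdm_on_load=True,            # --use-tqdm-on-load
)

# CLI equivalent:
#   vllm serve facebook/opt-125m --load-format safetensors \
#       --download-dir /tmp/vllm-models
```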
DecodingConfig#
Dataclass which contains the decoding strategy of the engine.
- --guided-decoding-backend
Possible choices: auto, guidance, lm-format-enforcer, outlines, xgrammar
Which engine will be used for guided decoding (JSON schema / regex etc) by default. With “auto”, we will make opinionated choices based on request contents and what the backend libraries currently support, so the behavior is subject to change in each release.
Default:
'auto'
- --guided-decoding-disable-fallback, --no-guided-decoding-disable-fallback
If `True`, vLLM will not fall back to a different backend on error.
Default:
False
- --guided-decoding-disable-any-whitespace, --no-guided-decoding-disable-any-whitespace
If `True`, the model will not generate any whitespace during guided decoding. This is only supported for the xgrammar and guidance backends.
Default:
False
- --guided-decoding-disable-additional-properties, --no-guided-decoding-disable-additional-properties
If `True`, the `guidance` backend will not use `additionalProperties` in the JSON schema. This is only supported for the `guidance` backend and is used to better align its behaviour with `outlines` and `xgrammar`.
Default:
False
- --enable-reasoning, --no-enable-reasoning
[DEPRECATED] The `--enable-reasoning` flag is deprecated as of v0.8.6. Use `--reasoning-parser` to specify the reasoning parser backend instead. This flag (`--enable-reasoning`) will be removed in v0.10.0. When `--reasoning-parser` is specified, reasoning mode is automatically enabled.
- --reasoning-parser
Possible choices: deepseek_r1, granite, qwen3
Select the reasoning parser depending on the model that you’re using. This is used to parse the reasoning content into OpenAI API format.
Default:
''
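The guided decoding backend chosen above is combined with per-request constraints at sampling time. A minimal sketch, assuming the request-level `GuidedDecodingParams` API is available in your vLLM version; the schema and model are illustrative:

```python
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Pick the guided-decoding backend at engine level (--guided-decoding-backend).
llm = LLM(model="facebook/opt-125m", guided_decoding_backend="xgrammar")

# Constrain a single request to a JSON schema (hypothetical schema).
schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}
params = SamplingParams(
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(json=schema),
)
print(llm.generate(["Give me a person as JSON:"], params)[0].outputs[0].text)
```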
ParallelConfig#
Configuration for the distributed execution.
- --distributed-executor-backend
Possible choices: external_launcher, mp, ray, uni, None
Backend to use for distributed model workers, either “ray” or “mp” (multiprocessing). If the product of pipeline_parallel_size and tensor_parallel_size is less than or equal to the number of GPUs available, “mp” will be used to keep processing on a single host. Otherwise, this will default to “ray” if Ray is installed and fail otherwise. Note that tpu and hpu only support Ray for distributed inference.
- --pipeline-parallel-size, -pp
Number of pipeline parallel groups.
Default:
1
- --tensor-parallel-size, -tp
Number of tensor parallel groups.
Default:
1
- --data-parallel-size, -dp
Number of data parallel groups. MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.
Default:
1
- --data-parallel-size-local, -dpl
Number of data parallel replicas to run on this node.
- --data-parallel-address, -dpa
Address of data parallel cluster head-node.
- --data-parallel-rpc-port, -dpp
Port for data parallel RPC communication.
- --enable-expert-parallel, --no-enable-expert-parallel
Use expert parallelism instead of tensor parallelism for MoE layers.
Default:
False
- --max-parallel-loading-workers
Maximum number of parallel loading workers when loading model sequentially in multiple batches. To avoid RAM OOM when using tensor parallel and large models.
- --ray-workers-use-nsight, --no-ray-workers-use-nsight
Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
Default:
False
- --disable-custom-all-reduce, --no-disable-custom-all-reduce
Disable the custom all-reduce kernel and fall back to NCCL.
Default:
False
- --worker-cls
The full name of the worker class to use. If “auto”, the worker class will be determined based on the platform.
Default:
'auto'
- --worker-extension-cls
The full name of the worker extension class to use. The worker extension class is dynamically inherited by the worker class. This is used to inject new attributes and methods to the worker class for use in collective_rpc calls.
Default:
''
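A sketch of the ParallelConfig options above for a single node, assuming 4 GPUs; the model name and sizes are illustrative:

```python
from vllm import LLM

# Illustrative ParallelConfig settings for one host with 4 GPUs.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    tensor_parallel_size=4,             # --tensor-parallel-size / -tp
    distributed_executor_backend="mp",  # --distributed-executor-backend
)

# CLI equivalent (pipeline/data parallel flags follow the same pattern):
#   vllm serve meta-llama/Llama-3.1-8B-Instruct -tp 4 \
#       --distributed-executor-backend mp
```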
CacheConfig#
Configuration for the KV cache.
- --block-size
Possible choices: 1, 8, 16, 32, 64, 128
Size of a contiguous cache block in number of tokens. This is ignored on neuron devices and set to `--max-model-len`. On CUDA devices, only block sizes up to 32 are supported. On HPU devices, block size defaults to 128.
This config has no static default. If left unspecified by the user, it will be set in `Platform.check_and_update_configs()` based on the current platform.
- --gpu-memory-utilization
The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory utilization. If unspecified, will use the default value of 0.9. This is a per-instance limit, and only applies to the current vLLM instance. It does not matter if you have another vLLM instance running on the same GPU. For example, if you have two vLLM instances running on the same GPU, you can set the GPU memory utilization to 0.5 for each instance.
Default:
0.9
- --swap-space
Size of the CPU swap space per GPU (in GiB).
Default:
4
- --kv-cache-dtype
Possible choices: auto, fp8, fp8_e4m3, fp8_e5m2
Data type for kv cache storage. If “auto”, will use model data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports fp8 (=fp8_e4m3).
Default:
'auto'
- --num-gpu-blocks-override
Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks` if specified. Does nothing if `None`. Used for testing preemption.
- --enable-prefix-caching, --no-enable-prefix-caching
Whether to enable prefix caching. Disabled by default for V0. Enabled by default for V1.
- --prefix-caching-hash-algo
Possible choices: builtin, sha256
Set the hash algorithm for prefix caching:
“builtin” is Python’s built-in hash.
“sha256” is collision resistant but with certain overheads.
Default:
'builtin'
- --cpu-offload-gb
The space in GiB to offload to CPU, per GPU. Default is 0, which means no offloading. Intuitively, this argument can be seen as a virtual way to increase the GPU memory size. For example, if you have one 24 GB GPU and set this to 10, virtually you can think of it as a 34 GB GPU. Then you can load a 13B model with BF16 weight, which requires at least 26GB GPU memory. Note that this requires fast CPU-GPU interconnect, as part of the model is loaded from CPU memory to GPU memory on the fly in each model forward pass.
Default:
0
- --calculate-kv-scales, --no-calculate-kv-scales
This enables dynamic calculation of `k_scale` and `v_scale` when `kv_cache_dtype` is fp8. If `False`, the scales will be loaded from the model checkpoint if available. Otherwise, the scales will default to 1.0.
Default:
False
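A sketch of common CacheConfig options above as `LLM` keyword arguments; the values are illustrative, not tuned recommendations:

```python
from vllm import LLM

# Illustrative CacheConfig settings.
llm = LLM(
    model="facebook/opt-125m",
    gpu_memory_utilization=0.8,  # --gpu-memory-utilization: use 80% of GPU memory
    swap_space=4,                # --swap-space: 4 GiB of CPU swap per GPU
    kv_cache_dtype="auto",       # --kv-cache-dtype
    enable_prefix_caching=True,  # --enable-prefix-caching
    cpu_offload_gb=0,            # --cpu-offload-gb
)

# CLI equivalent:
#   vllm serve facebook/opt-125m --gpu-memory-utilization 0.8 --swap-space 4 \
#       --kv-cache-dtype auto --enable-prefix-caching
```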
TokenizerPoolConfig#
This config is deprecated and will be removed in a future release.
Passing these parameters will have no effect. Please remove them from your
configurations.
- --tokenizer-pool-size
This parameter is deprecated and will be removed in a future release. Passing this parameter will have no effect. Please remove it from your configurations.
Default:
0
- --tokenizer-pool-type
This parameter is deprecated and will be removed in a future release. Passing this parameter will have no effect. Please remove it from your configurations.
Default:
'ray'
- --tokenizer-pool-extra-config
This parameter is deprecated and will be removed in a future release. Passing this parameter will have no effect. Please remove it from your configurations.
Should either be a valid JSON string or JSON keys passed individually. For example, the following sets of arguments are equivalent:
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`
Default:
{}
MultiModalConfig#
Controls the behavior of multimodal models.
- --limit-mm-per-prompt
The maximum number of input items allowed per prompt for each modality. Defaults to 1 (V0) or 999 (V1) for each modality.
For example, to allow up to 16 images and 2 videos per prompt:
{"images": 16, "videos": 2}
Should either be a valid JSON string or JSON keys passed individually. For example, the following sets of arguments are equivalent:
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`
Default:
{}
- --mm-processor-kwargs
Overrides for the multi-modal processor obtained from `transformers.AutoProcessor.from_pretrained`. The available overrides depend on the model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
Should either be a valid JSON string or JSON keys passed individually. For example, the following sets of arguments are equivalent:
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`
- --disable-mm-preprocessor-cache, --no-disable-mm-preprocessor-cache
If `True`, disable caching of the processed multi-modal inputs.
Default:
False
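A hedged sketch of the multimodal options above for a vision-language model; the model name, modality keys, and processor kwargs are illustrative and depend on the model you run:

```python
from vllm import LLM

# Illustrative MultiModalConfig settings for a vision-language model.
llm = LLM(
    model="microsoft/Phi-3-vision-128k-instruct",
    trust_remote_code=True,
    limit_mm_per_prompt={"image": 2},      # --limit-mm-per-prompt (modality key assumed)
    mm_processor_kwargs={"num_crops": 4},  # --mm-processor-kwargs (Phi-3-Vision example above)
)

# CLI equivalent:
#   vllm serve microsoft/Phi-3-vision-128k-instruct --trust-remote-code \
#       --limit-mm-per-prompt '{"image": 2}' --mm-processor-kwargs '{"num_crops": 4}'
```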
LoRAConfig#
Configuration for LoRA.
- --enable-lora, --no-enable-lora
If True, enable handling of LoRA adapters.
- --enable-lora-bias, --no-enable-lora-bias
Enable bias for LoRA adapters.
Default:
False
- --max-loras
Max number of LoRAs in a single batch.
Default:
1
- --max-lora-rank
Max LoRA rank.
Default:
16
- --lora-extra-vocab-size
Maximum size of extra vocabulary that can be present in a LoRA adapter (added to the base model vocabulary).
Default:
256
- --lora-dtype
Possible choices: auto, bfloat16, float16
Data type for LoRA. If auto, will default to base model dtype.
Default:
'auto'
- --long-lora-scaling-factors
Specify multiple scaling factors (which can be different from the base model scaling factor; see e.g. Long LoRA) to allow multiple LoRA adapters trained with those scaling factors to be used at the same time. If not specified, only adapters trained with the base model scaling factor are allowed.
- --max-cpu-loras
Maximum number of LoRAs to store in CPU memory. Must be >= `max_loras`.
- --fully-sharded-loras, --no-fully-sharded-loras
By default, only half of the LoRA computation is sharded with tensor parallelism. Enabling this will use the fully sharded layers. At high sequence length, max rank or tensor parallel size, this is likely faster.
Default:
False
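A hedged sketch of serving a LoRA adapter with the options above; the base model, adapter name, and adapter path are illustrative:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Illustrative LoRAConfig settings.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    enable_lora=True,   # --enable-lora
    max_loras=2,        # --max-loras
    max_lora_rank=16,   # --max-lora-rank
)

# Attach a LoRA adapter to a request: (name, unique integer id, local path);
# the path is a hypothetical placeholder.
lora = LoRARequest("my-adapter", 1, "/path/to/adapter")
out = llm.generate(["Hello"], SamplingParams(max_tokens=16), lora_request=lora)
print(out[0].outputs[0].text)
```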
PromptAdapterConfig#
Configuration for PromptAdapters.
- --enable-prompt-adapter, --no-enable-prompt-adapter
If True, enable handling of PromptAdapters.
- --max-prompt-adapters
Max number of PromptAdapters in a batch.
Default:
1
- --max-prompt-adapter-token
Max number of PromptAdapters tokens.
Default:
0
DeviceConfig#
Configuration for the device to use for vLLM execution.
- --device
Possible choices: auto, cpu, cuda, hpu, neuron, tpu, xpu
Device type for vLLM execution.
Default:
'auto'
SpeculativeConfig#
Configuration for speculative decoding.
- --speculative-config
The configurations for speculative decoding. Should be a JSON string.
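A hedged sketch of `--speculative-config`; the exact keys accepted depend on your vLLM version, and the draft model is illustrative:

```python
from vllm import LLM

# Illustrative speculative decoding configuration with a separate draft model.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "model": "meta-llama/Llama-3.2-1B-Instruct",  # draft model (assumption)
        "num_speculative_tokens": 5,
    },
)

# CLI equivalent (the value must be a JSON string):
#   vllm serve meta-llama/Llama-3.1-8B-Instruct \
#       --speculative-config '{"model": "meta-llama/Llama-3.2-1B-Instruct", "num_speculative_tokens": 5}'
```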
ObservabilityConfig#
Configuration for observability - metrics and tracing.
- --show-hidden-metrics-for-version
Enable deprecated Prometheus metrics that have been hidden since the specified version. For example, if a previously deprecated metric has been hidden since the v0.7.0 release, you can use `--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while you migrate to new metrics. The metric is likely to be removed completely in an upcoming release.
- --otlp-traces-endpoint
Target URL to which OpenTelemetry traces will be sent.
- --collect-detailed-traces
Possible choices: all, model, worker, None, model,worker, model,all, worker,model, worker,all, all,model, all,worker
It makes sense to set this only if `--otlp-traces-endpoint` is set. If set, it will collect detailed traces for the specified modules. This involves use of possibly costly and/or blocking operations and hence might have a performance impact.
Note that collecting detailed timing information for each request can be expensive.
SchedulerConfig#
Scheduler configuration.
- --max-num-batched-tokens
Maximum number of tokens to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.
- --max-num-seqs
Maximum number of sequences to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.
- --max-num-partial-prefills
For chunked prefill, the maximum number of sequences that can be partially prefilled concurrently.
Default:
1
- --max-long-partial-prefills
For chunked prefill, the maximum number of prompts longer than long_prefill_token_threshold that will be prefilled concurrently. Setting this less than max_num_partial_prefills will allow shorter prompts to jump the queue in front of longer prompts in some cases, improving latency.
Default:
1
- --cuda-graph-sizes
CUDA graph capture sizes; the default is 512.
If one value is provided, the capture list follows the pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)].
If more than one value is provided (e.g. 1 2 128), the capture list follows the provided list.
Default:
[512]
- --long-prefill-token-threshold
For chunked prefill, a request is considered long if the prompt is longer than this number of tokens.
Default:
0
- --num-lookahead-slots
The number of slots to allocate per sequence per step, beyond the known token ids. This is used in speculative decoding to store KV activations of tokens which may or may not be accepted.
NOTE: This will be replaced by speculative config in the future; it is present to enable correctness tests until then.
Default:
0
- --scheduler-delay-factor
Apply a delay (of delay factor multiplied by previous prompt latency) before scheduling next prompt.
Default:
0.0
- --preemption-mode
Possible choices: recompute, swap, None
Whether to perform preemption by swapping or recomputation. If not specified, we determine the mode as follows: We use recomputation by default since it incurs lower overhead than swapping. However, when the sequence group has multiple sequences (e.g., beam search), recomputation is not currently supported. In such a case, we use swapping instead.
- --num-scheduler-steps
Maximum number of forward steps per scheduler call.
Default:
1
- --multi-step-stream-outputs, --no-multi-step-stream-outputs
If False, then multi-step will stream outputs at the end of all steps
Default:
True
- --scheduling-policy
Possible choices: fcfs, priority
The scheduling policy to use:
“fcfs” means first come first served, i.e. requests are handled in order of arrival.
“priority” means requests are handled based on a given priority (lower value means earlier handling), with time of arrival deciding any ties.
Default:
'fcfs'
- --enable-chunked-prefill, --no-enable-chunked-prefill
If True, prefill requests can be chunked based on the remaining max_num_batched_tokens.
- --disable-chunked-mm-input, --no-disable-chunked-mm-input
If set to true and chunked prefill is enabled, we do not want to partially schedule a multimodal item. Only used in V1. This ensures that if a request has a mixed prompt (like text tokens TTTT followed by image tokens IIIIIIIIII) where only some image tokens can be scheduled (like TTTTIIIII, leaving IIIII), it will be scheduled as TTTT in one step and IIIIIIIIII in the next.
Default:
False
- --scheduler-cls
The scheduler class to use. “vllm.core.scheduler.Scheduler” is the default scheduler. Can be a class directly or the path to a class of form “mod.custom_class”.
Default:
'vllm.core.scheduler.Scheduler'
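A sketch of common SchedulerConfig options above for a throughput-oriented setup; the values are illustrative, not tuned recommendations:

```python
from vllm import LLM

# Illustrative SchedulerConfig settings.
llm = LLM(
    model="facebook/opt-125m",
    max_num_batched_tokens=8192,  # --max-num-batched-tokens
    max_num_seqs=128,             # --max-num-seqs
    enable_chunked_prefill=True,  # --enable-chunked-prefill
)

# CLI equivalent:
#   vllm serve facebook/opt-125m --max-num-batched-tokens 8192 \
#       --max-num-seqs 128 --enable-chunked-prefill
```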
VllmConfig#
Dataclass which contains all vllm-related configuration. This simplifies passing around the distinct configurations in the codebase.
- --kv-transfer-config
The configurations for distributed KV cache transfer.
Should either be a valid JSON string or JSON keys passed individually. For example, the following sets of arguments are equivalent:
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`
- --kv-events-config
The configurations for event publishing.
Should either be a valid JSON string or JSON keys passed individually. For example, the following sets of arguments are equivalent:
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`
- --compilation-config, -O
`torch.compile` configuration for the model.
When it is a number (0, 1, 2, 3), it will be interpreted as the optimization level.
NOTE: level 0 is the default level without any optimization. level 1 and 2 are for internal testing only. level 3 is the recommended level for production.
Following the convention of traditional compilers, using `-O` without a space is also supported. `-O3` is equivalent to `-O 3`.
You can specify the full compilation config like so:
{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}
Should either be a valid JSON string or JSON keys passed individually. For example, the following sets of arguments are equivalent:
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`
Default:
{}
- --additional-config
Additional config for specified platform. Different platforms may support different configs. Make sure the configs are valid for the platform you are using. Contents must be hashable.
Default:
{}
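A sketch of passing a compilation config, mirroring the `-O`/`--compilation-config` forms documented above; values are illustrative:

```python
from vllm import LLM

# Illustrative compilation config; equivalent to -O3 plus explicit capture sizes.
llm = LLM(
    model="facebook/opt-125m",
    compilation_config={"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]},
)

# CLI equivalents:
#   vllm serve facebook/opt-125m -O3
#   vllm serve facebook/opt-125m \
#       --compilation-config '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}'
```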
Async Engine Arguments#
Additional arguments are available to the asynchronous engine which is used for online serving:
usage: vllm serve [-h] [--disable-log-requests]
Named Arguments#
- --disable-log-requests
Disable logging requests.