Engine Arguments
Below, you can find an explanation of every engine argument for vLLM:
usage: vllm serve [-h] [--model MODEL] [--tokenizer TOKENIZER]
[--skip-tokenizer-init] [--revision REVISION]
[--code-revision CODE_REVISION]
[--tokenizer-revision TOKENIZER_REVISION]
[--tokenizer-mode {auto,slow}] [--trust-remote-code]
[--download-dir DOWNLOAD_DIR]
[--load-format {auto,pt,safetensors,npcache,dummy,tensorizer,sharded_state,gguf,bitsandbytes}]
[--dtype {auto,half,float16,bfloat16,float,float32}]
[--kv-cache-dtype {auto,fp8,fp8_e5m2,fp8_e4m3}]
[--quantization-param-path QUANTIZATION_PARAM_PATH]
[--max-model-len MAX_MODEL_LEN]
[--guided-decoding-backend {outlines,lm-format-enforcer}]
[--distributed-executor-backend {ray,mp}] [--worker-use-ray]
[--pipeline-parallel-size PIPELINE_PARALLEL_SIZE]
[--tensor-parallel-size TENSOR_PARALLEL_SIZE]
[--max-parallel-loading-workers MAX_PARALLEL_LOADING_WORKERS]
[--ray-workers-use-nsight]
[--block-size {8,16,32,128,256,512,1024,2048}]
[--enable-prefix-caching] [--disable-sliding-window]
[--use-v2-block-manager]
[--num-lookahead-slots NUM_LOOKAHEAD_SLOTS] [--seed SEED]
[--swap-space SWAP_SPACE] [--cpu-offload-gb CPU_OFFLOAD_GB]
[--gpu-memory-utilization GPU_MEMORY_UTILIZATION]
[--num-gpu-blocks-override NUM_GPU_BLOCKS_OVERRIDE]
[--max-num-batched-tokens MAX_NUM_BATCHED_TOKENS]
[--max-num-seqs MAX_NUM_SEQS] [--max-logprobs MAX_LOGPROBS]
[--disable-log-stats]
[--quantization {aqlm,awq,deepspeedfp,tpu_int8,fp8,fbgemm_fp8,marlin,gguf,gptq_marlin_24,gptq_marlin,awq_marlin,gptq,squeezellm,compressed-tensors,bitsandbytes,qqq,experts_int8,None}]
[--rope-scaling ROPE_SCALING] [--rope-theta ROPE_THETA]
[--enforce-eager]
[--max-context-len-to-capture MAX_CONTEXT_LEN_TO_CAPTURE]
[--max-seq-len-to-capture MAX_SEQ_LEN_TO_CAPTURE]
[--disable-custom-all-reduce]
[--tokenizer-pool-size TOKENIZER_POOL_SIZE]
[--tokenizer-pool-type TOKENIZER_POOL_TYPE]
[--tokenizer-pool-extra-config TOKENIZER_POOL_EXTRA_CONFIG]
[--limit-mm-per-prompt LIMIT_MM_PER_PROMPT] [--enable-lora]
[--max-loras MAX_LORAS] [--max-lora-rank MAX_LORA_RANK]
[--lora-extra-vocab-size LORA_EXTRA_VOCAB_SIZE]
[--lora-dtype {auto,float16,bfloat16,float32}]
[--long-lora-scaling-factors LONG_LORA_SCALING_FACTORS]
[--max-cpu-loras MAX_CPU_LORAS] [--fully-sharded-loras]
[--enable-prompt-adapter]
[--max-prompt-adapters MAX_PROMPT_ADAPTERS]
[--max-prompt-adapter-token MAX_PROMPT_ADAPTER_TOKEN]
[--device {auto,cuda,neuron,cpu,openvino,tpu,xpu}]
[--num-scheduler-steps NUM_SCHEDULER_STEPS]
[--scheduler-delay-factor SCHEDULER_DELAY_FACTOR]
[--enable-chunked-prefill [ENABLE_CHUNKED_PREFILL]]
[--speculative-model SPECULATIVE_MODEL]
[--speculative-model-quantization {aqlm,awq,deepspeedfp,tpu_int8,fp8,fbgemm_fp8,marlin,gguf,gptq_marlin_24,gptq_marlin,awq_marlin,gptq,squeezellm,compressed-tensors,bitsandbytes,qqq,experts_int8,None}]
[--num-speculative-tokens NUM_SPECULATIVE_TOKENS]
[--speculative-draft-tensor-parallel-size SPECULATIVE_DRAFT_TENSOR_PARALLEL_SIZE]
[--speculative-max-model-len SPECULATIVE_MAX_MODEL_LEN]
[--speculative-disable-by-batch-size SPECULATIVE_DISABLE_BY_BATCH_SIZE]
[--ngram-prompt-lookup-max NGRAM_PROMPT_LOOKUP_MAX]
[--ngram-prompt-lookup-min NGRAM_PROMPT_LOOKUP_MIN]
[--spec-decoding-acceptance-method {rejection_sampler,typical_acceptance_sampler}]
[--typical-acceptance-sampler-posterior-threshold TYPICAL_ACCEPTANCE_SAMPLER_POSTERIOR_THRESHOLD]
[--typical-acceptance-sampler-posterior-alpha TYPICAL_ACCEPTANCE_SAMPLER_POSTERIOR_ALPHA]
[--disable-logprobs-during-spec-decoding [DISABLE_LOGPROBS_DURING_SPEC_DECODING]]
[--model-loader-extra-config MODEL_LOADER_EXTRA_CONFIG]
[--ignore-patterns IGNORE_PATTERNS]
[--preemption-mode PREEMPTION_MODE]
[--served-model-name SERVED_MODEL_NAME [SERVED_MODEL_NAME ...]]
[--qlora-adapter-name-or-path QLORA_ADAPTER_NAME_OR_PATH]
[--otlp-traces-endpoint OTLP_TRACES_ENDPOINT]
[--collect-detailed-traces COLLECT_DETAILED_TRACES]
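As a rough illustration of how these flags combine, the sketch below assembles a `vllm serve` command line in Python without running it. The flag names come from the usage text above; the helper name `build_serve_command` and the chosen values are illustrative assumptions, not part of vLLM.

```python
import shlex


def build_serve_command(model, **options):
    """Assemble an argv list for `vllm serve` from keyword options.

    Hypothetical helper: keyword names become flags by replacing
    underscores with hyphens. True becomes a bare switch, while
    False and None are skipped.
    """
    argv = ["vllm", "serve", "--model", model]
    for name, value in options.items():
        flag = "--" + name.replace("_", "-")
        if value is True:
            argv.append(flag)
        elif value is False or value is None:
            continue
        else:
            argv.extend([flag, str(value)])
    return argv


argv = build_serve_command(
    "facebook/opt-125m",
    dtype="auto",
    tensor_parallel_size=1,
    gpu_memory_utilization=0.9,
    enable_prefix_caching=True,
)
# Print a shell-quoted command line that could be pasted into a terminal.
print(shlex.join(argv))
```

Building the argument list first and quoting it with `shlex.join` keeps values such as JSON strings intact when the command is later pasted into a shell.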
Named Arguments
- --model
Name or path of the huggingface model to use.
Default: “facebook/opt-125m”
- --tokenizer
Name or path of the huggingface tokenizer to use. If unspecified, model name or path will be used.
- --skip-tokenizer-init
Skip initialization of tokenizer and detokenizer
- --revision
The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
- --code-revision
The specific revision to use for the model code on Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
- --tokenizer-revision
Revision of the huggingface tokenizer to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
- --tokenizer-mode
Possible choices: auto, slow
The tokenizer mode.
“auto” will use the fast tokenizer if available.
“slow” will always use the slow tokenizer.
Default: “auto”
- --trust-remote-code
Trust remote code from huggingface.
- --download-dir
Directory to download and load the weights. Defaults to the default Hugging Face cache directory.
- --load-format
Possible choices: auto, pt, safetensors, npcache, dummy, tensorizer, sharded_state, gguf, bitsandbytes
The format of the model weights to load.
“auto” will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available.
“pt” will load the weights in the pytorch bin format.
“safetensors” will load the weights in the safetensors format.
“npcache” will load the weights in pytorch format and store a numpy cache to speed up the loading.
“dummy” will initialize the weights with random values, which is mainly for profiling.
“tensorizer” will load the weights using tensorizer from CoreWeave. See the Tensorize vLLM Model script in the Examples section for more information.
“bitsandbytes” will load the weights using bitsandbytes quantization.
Default: “auto”
- --dtype
Possible choices: auto, half, float16, bfloat16, float, float32
Data type for model weights and activations.
“auto” will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
“half” for FP16. Recommended for AWQ quantization.
“float16” is the same as “half”.
“bfloat16” for a balance between precision and range.
“float” is shorthand for FP32 precision.
“float32” for FP32 precision.
Default: “auto”
- --kv-cache-dtype
Possible choices: auto, fp8, fp8_e5m2, fp8_e4m3
Data type for KV cache storage. If “auto”, will use the model data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports fp8 (=fp8_e4m3).
Default: “auto”
- --quantization-param-path
Path to the JSON file containing the KV cache scaling factors. This should generally be supplied when the KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on CUDA versions greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
- --max-model-len
Model context length. If unspecified, will be automatically derived from the model config.
- --guided-decoding-backend
Possible choices: outlines, lm-format-enforcer
Which engine will be used for guided decoding (JSON schema, regex, etc.) by default. Currently supports outlines-dev/outlines and noamgat/lm-format-enforcer. Can be overridden per request via the guided_decoding_backend parameter.
Default: “outlines”
- --distributed-executor-backend
Possible choices: ray, mp
Backend to use for distributed serving. When more than 1 GPU is used, will be automatically set to “ray” if installed or “mp” (multiprocessing) otherwise.
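The default described above can be sketched as a small selection function. This is a simplified illustration, not vLLM's actual code: the function name `default_distributed_backend` is hypothetical, and returning `None` for a single GPU is an assumption, since the text only describes the multi-GPU case.

```python
import importlib.util


def default_distributed_backend(num_gpus, ray_installed=None):
    """Pick a backend following the documented default.

    Assumption: with one GPU no distributed backend is needed, so None
    is returned. With more than one GPU, "ray" is chosen when the ray
    package can be found, and "mp" (multiprocessing) otherwise.
    """
    if num_gpus <= 1:
        return None
    if ray_installed is None:
        # Check whether the ray package is importable without importing it.
        ray_installed = importlib.util.find_spec("ray") is not None
    return "ray" if ray_installed else "mp"


print(default_distributed_backend(4, ray_installed=True))
print(default_distributed_backend(4, ray_installed=False))
```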
- --worker-use-ray
Deprecated, use --distributed-executor-backend=ray.
- --pipeline-parallel-size, -pp
Number of pipeline stages.
Default: 1
- --tensor-parallel-size, -tp
Number of tensor parallel replicas.
Default: 1
- --max-parallel-loading-workers
Load model sequentially in multiple batches, to avoid RAM OOM when using tensor parallel and large models.
- --ray-workers-use-nsight
If specified, use nsight to profile Ray workers.
- --block-size
Possible choices: 8, 16, 32, 128, 256, 512, 1024, 2048
Token block size for contiguous chunks of tokens.
Default: 16
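To get a feel for what the block size means, the sketch below counts how many blocks a sequence would occupy. It assumes that each block holds exactly `block_size` tokens and that a sequence always occupies whole blocks. The function name `blocks_needed` is illustrative and not part of vLLM.

```python
import math


def blocks_needed(num_tokens, block_size=16):
    """Number of fixed-size token blocks needed to hold num_tokens tokens.

    Assumption: a partially filled final block still counts as one block.
    """
    return math.ceil(num_tokens / block_size)


print(blocks_needed(100))      # default block size of 16
print(blocks_needed(100, 32))  # larger blocks, fewer of them
```

Under these assumptions a 100-token sequence takes 7 blocks of 16 tokens or 4 blocks of 32 tokens. Larger blocks mean fewer blocks to track but more unused slots in the last block.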
- --enable-prefix-caching
Enables automatic prefix caching.
- --disable-sliding-window
Disables sliding window, capping to sliding window size
- --use-v2-block-manager
Use BlockSpaceManagerV2.
- --num-lookahead-slots
Experimental scheduling config necessary for speculative decoding. This will be replaced by speculative config in the future; it is present to enable correctness tests until then.
Default: 0
- --seed
Random seed for operations.
Default: 0
- --swap-space
CPU swap space size (GiB) per GPU.
Default: 4
- --cpu-offload-gb
The space in GiB to offload to CPU, per GPU. Default is 0, which means no offloading. Intuitively, this argument can be seen as a virtual way to increase the GPU memory size. For example, if you have one 24 GB GPU and set this to 10, virtually you can think of it as a 34 GB GPU. Then you can load a 13B model with BF16 weights, which requires at least 26 GB of GPU memory. Note that this requires a fast CPU-GPU interconnect, as part of the model is loaded from CPU memory to GPU memory on the fly in each model forward pass.
Default: 0
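The arithmetic in the example above can be checked with a short sketch. It counts weight memory only and ignores the KV cache, activations, and other overhead, so it is a lower bound rather than a sizing tool. The function names are illustrative assumptions.

```python
def weight_memory_gb(num_params_billion, bytes_per_param):
    """Approximate weight size in GB: billions of parameters times bytes each."""
    return num_params_billion * bytes_per_param


def fits_with_offload(gpu_gb, cpu_offload_gb, weights_gb):
    """True if the weights fit in GPU memory plus the CPU offload space."""
    return weights_gb <= gpu_gb + cpu_offload_gb


# A 13B model in BF16 uses 2 bytes per parameter.
weights = weight_memory_gb(13, 2)
print(weights)
print(fits_with_offload(24, 0, weights))   # 24 GB GPU, no offload
print(fits_with_offload(24, 10, weights))  # 24 GB GPU, --cpu-offload-gb 10
```

With no offload the 26 GB of weights exceed the 24 GB GPU. With 10 GB of offload the combined 34 GB is enough.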
- --gpu-memory-utilization
The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory utilization. If unspecified, will use the default value of 0.9.
Default: 0.9
- --num-gpu-blocks-override
If specified, ignore the GPU profiling result and use this number of GPU blocks. Used for testing preemption.
- --max-num-batched-tokens
Maximum number of batched tokens per iteration.
- --max-num-seqs
Maximum number of sequences per iteration.
Default: 256
- --max-logprobs
Max number of log probs to return when logprobs is specified in SamplingParams.
Default: 20
- --disable-log-stats
Disable logging statistics.
- --quantization, -q
Possible choices: aqlm, awq, deepspeedfp, tpu_int8, fp8, fbgemm_fp8, marlin, gguf, gptq_marlin_24, gptq_marlin, awq_marlin, gptq, squeezellm, compressed-tensors, bitsandbytes, qqq, experts_int8, None
Method used to quantize the weights. If None, we first check the quantization_config attribute in the model config file. If that is None, we assume the model weights are not quantized and use dtype to determine the data type of the weights.
- --rope-scaling
RoPE scaling configuration in JSON format. For example, {"type":"dynamic","factor":2.0}
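Because the value must be valid JSON with straight double quotes, it can be safer to generate it than to type it by hand. The sketch below builds the example value with the standard library and shell-quotes it. Only the `type` and `factor` keys from the example above are used.

```python
import json
import shlex

# The example configuration from the description above.
rope_scaling = {"type": "dynamic", "factor": 2.0}

# Compact separators reproduce the form shown in the documentation.
value = json.dumps(rope_scaling, separators=(",", ":"))
print(value)

# shlex.join adds the quoting a POSIX shell needs around the JSON string.
print(shlex.join(["--rope-scaling", value]))
```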
- --rope-theta
RoPE theta. Use with rope_scaling. In some cases, changing the RoPE theta improves the performance of the scaled model.
- --enforce-eager
Always use eager-mode PyTorch. If False, will use eager mode and CUDA graph in hybrid for maximal performance and flexibility.
- --max-context-len-to-capture
Maximum context length covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. (DEPRECATED. Use --max-seq-len-to-capture instead.)
- --max-seq-len-to-capture
Maximum sequence length covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode.
Default: 8192
- --disable-custom-all-reduce
See ParallelConfig.
- --tokenizer-pool-size
Size of tokenizer pool to use for asynchronous tokenization. If 0, will use synchronous tokenization.
Default: 0
- --tokenizer-pool-type
Type of tokenizer pool to use for asynchronous tokenization. Ignored if tokenizer_pool_size is 0.
Default: “ray”
- --tokenizer-pool-extra-config
Extra config for tokenizer pool. This should be a JSON string that will be parsed into a dictionary. Ignored if tokenizer_pool_size is 0.
- --limit-mm-per-prompt
For each multimodal plugin, limit how many input instances to allow for each prompt. Expects a comma-separated list of items, e.g.: image=16,video=2 allows a maximum of 16 images and 2 videos per prompt. Defaults to 1 for each modality.
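The comma-separated format described above can be read with a few lines of Python. This is an illustrative parser, not vLLM's own. The names `parse_limit_mm_per_prompt` and `limit_for` are hypothetical, and the fallback of 1 follows the default stated in the description.

```python
def parse_limit_mm_per_prompt(text):
    """Parse a string such as "image=16,video=2" into a dict of limits."""
    limits = {}
    for item in text.split(","):
        item = item.strip()
        if not item:
            continue  # tolerate empty input or a trailing comma
        key, _, count = item.partition("=")
        limits[key.strip()] = int(count)
    return limits


def limit_for(limits, modality):
    """Return the limit for a modality, falling back to the default of 1."""
    return limits.get(modality, 1)


limits = parse_limit_mm_per_prompt("image=16,video=2")
print(limits)
print(limit_for(limits, "audio"))  # not listed, so the default applies
```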
- --enable-lora
If True, enable handling of LoRA adapters.
- --max-loras
Max number of LoRAs in a single batch.
Default: 1
- --max-lora-rank
Max LoRA rank.
Default: 16
- --lora-extra-vocab-size
Maximum size of extra vocabulary that can be present in a LoRA adapter (added to the base model vocabulary).
Default: 256
- --lora-dtype
Possible choices: auto, float16, bfloat16, float32
Data type for LoRA. If auto, will default to base model dtype.
Default: “auto”
- --long-lora-scaling-factors
Specify multiple scaling factors (which can be different from the base model scaling factor, see e.g. Long LoRA) to allow for multiple LoRA adapters trained with those scaling factors to be used at the same time. If not specified, only adapters trained with the base model scaling factor are allowed.
- --max-cpu-loras
Maximum number of LoRAs to store in CPU memory. Must be >= max_num_seqs. Defaults to max_num_seqs.
- --fully-sharded-loras
By default, only half of the LoRA computation is sharded with tensor parallelism. Enabling this will use the fully sharded layers. At high sequence length, max rank or tensor parallel size, this is likely faster.
- --enable-prompt-adapter
If True, enable handling of PromptAdapters.
- --max-prompt-adapters
Max number of PromptAdapters in a batch.
Default: 1
- --max-prompt-adapter-token
Max number of PromptAdapter tokens.
Default: 0
- --device
Possible choices: auto, cuda, neuron, cpu, openvino, tpu, xpu
Device type for vLLM execution.
Default: “auto”
- --num-scheduler-steps
Maximum number of forward steps per scheduler call.
Default: 1
- --scheduler-delay-factor
Apply a delay (of delay factor multiplied by previous prompt latency) before scheduling the next prompt.
Default: 0.0
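The delay described above is a simple product, which the following sketch makes explicit. The function name and the sample latency are illustrative assumptions. How the scheduler measures the previous prompt latency is not specified here.

```python
def scheduling_delay(delay_factor, previous_prompt_latency_s):
    """Delay in seconds before the next prompt is scheduled."""
    return delay_factor * previous_prompt_latency_s


# With the default factor of 0.0 there is no added delay.
print(scheduling_delay(0.0, 0.2))
# With a factor of 0.5 and a previous prompt latency of 0.2 s,
# the next prompt waits half of that latency.
print(scheduling_delay(0.5, 0.2))
```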
- --enable-chunked-prefill
If set, the prefill requests can be chunked based on the max_num_batched_tokens.
- --speculative-model
The name of the draft model to be used in speculative decoding.
- --speculative-model-quantization
Possible choices: aqlm, awq, deepspeedfp, tpu_int8, fp8, fbgemm_fp8, marlin, gguf, gptq_marlin_24, gptq_marlin, awq_marlin, gptq, squeezellm, compressed-tensors, bitsandbytes, qqq, experts_int8, None
Method used to quantize the weights of the speculative model. If None, we first check the quantization_config attribute in the model config file. If that is None, we assume the model weights are not quantized and use dtype to determine the data type of the weights.
- --num-speculative-tokens
The number of speculative tokens to sample from the draft model in speculative decoding.
- --speculative-draft-tensor-parallel-size, -spec-draft-tp
Number of tensor parallel replicas for the draft model in speculative decoding.
- --speculative-max-model-len
The maximum sequence length supported by the draft model. Sequences over this length will skip speculation.
- --speculative-disable-by-batch-size
Disable speculative decoding for new incoming requests if the number of enqueued requests is larger than this value.
- --ngram-prompt-lookup-max
Max size of window for ngram prompt lookup in speculative decoding.
- --ngram-prompt-lookup-min
Min size of window for ngram prompt lookup in speculative decoding.
- --spec-decoding-acceptance-method
Possible choices: rejection_sampler, typical_acceptance_sampler
Specify the acceptance method to use during draft token verification in speculative decoding. Two types of acceptance routines are supported: 1) RejectionSampler which does not allow changing the acceptance rate of draft tokens, 2) TypicalAcceptanceSampler which is configurable, allowing for a higher acceptance rate at the cost of lower quality, and vice versa.
Default: “rejection_sampler”
- --typical-acceptance-sampler-posterior-threshold
Set the lower bound threshold for the posterior probability of a token to be accepted. This threshold is used by the TypicalAcceptanceSampler to make sampling decisions during speculative decoding. Defaults to 0.09.
- --typical-acceptance-sampler-posterior-alpha
A scaling factor for the entropy-based threshold for token acceptance in the TypicalAcceptanceSampler. Typically defaults to the square root of --typical-acceptance-sampler-posterior-threshold, i.e. 0.3.
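The relationship between the two defaults can be verified directly. The sketch assumes only what the descriptions state: a threshold default of 0.09 and an alpha default equal to its square root. The function name is illustrative.

```python
import math


def default_posterior_alpha(posterior_threshold=0.09):
    """Alpha default derived as the square root of the threshold."""
    return math.sqrt(posterior_threshold)


print(default_posterior_alpha())
```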
- --disable-logprobs-during-spec-decoding
If set to True, token log probabilities are not returned during speculative decoding. If set to False, log probabilities are returned according to the settings in SamplingParams. If not specified, it defaults to True. Disabling log probabilities during speculative decoding reduces latency by skipping logprob calculation in proposal sampling, target sampling, and after accepted tokens are determined.
- --model-loader-extra-config
Extra config for model loader. This will be passed to the model loader corresponding to the chosen load_format. This should be a JSON string that will be parsed into a dictionary.
- --ignore-patterns
The pattern(s) to ignore when loading the model. Defaults to ‘original/**/*’ to avoid repeated loading of llama’s checkpoints.
Default: []
- --preemption-mode
If ‘recompute’, the engine performs preemption by recomputing; If ‘swap’, the engine performs preemption by block swapping.
- --served-model-name
The model name(s) used in the API. If multiple names are provided, the server will respond to any of the provided names. The model name in the model field of a response will be the first name in this list. If not specified, the model name will be the same as the --model argument. Note that the name(s) will also be used in the model_name tag of Prometheus metrics. If multiple names are provided, the metrics tag will take the first one.
- --qlora-adapter-name-or-path
Name or path of the QLoRA adapter.
- --otlp-traces-endpoint
Target URL to which OpenTelemetry traces will be sent.
- --collect-detailed-traces
Valid choices are model, worker, all. It makes sense to set this only if --otlp-traces-endpoint is set. If set, it will collect detailed traces for the specified modules. This involves the use of possibly costly and/or blocking operations and hence might have a performance impact.
Async Engine Arguments
Below are the additional arguments related to the asynchronous engine:
usage: vllm serve [-h] [--engine-use-ray] [--disable-log-requests]
Named Arguments
- --engine-use-ray
Use Ray to start the LLM engine in a separate process from the server process. (DEPRECATED. This argument is deprecated and will be removed in a future update. Set VLLM_ALLOW_ENGINE_USE_RAY=1 to force use it. See vllm-project/vllm#7045.)
- --disable-log-requests
Disable logging requests.