Engine Arguments#

Below, you can find an explanation of every engine argument for vLLM:

usage: -m vllm.entrypoints.openai.api_server [-h] [--model MODEL]
                                             [--tokenizer TOKENIZER]
                                             [--skip-tokenizer-init]
                                             [--revision REVISION]
                                             [--code-revision CODE_REVISION]
                                             [--tokenizer-revision TOKENIZER_REVISION]
                                             [--tokenizer-mode {auto,slow}]
                                             [--trust-remote-code]
                                             [--download-dir DOWNLOAD_DIR]
                                             [--load-format {auto,pt,safetensors,npcache,dummy,tensorizer}]
                                             [--dtype {auto,half,float16,bfloat16,float,float32}]
                                             [--kv-cache-dtype {auto,fp8}]
                                             [--quantization-param-path QUANTIZATION_PARAM_PATH]
                                             [--max-model-len MAX_MODEL_LEN]
                                             [--guided-decoding-backend {outlines,lm-format-enforcer}]
                                             [--worker-use-ray]
                                             [--pipeline-parallel-size PIPELINE_PARALLEL_SIZE]
                                             [--tensor-parallel-size TENSOR_PARALLEL_SIZE]
                                             [--max-parallel-loading-workers MAX_PARALLEL_LOADING_WORKERS]
                                             [--ray-workers-use-nsight]
                                             [--block-size {8,16,32}]
                                             [--enable-prefix-caching]
                                             [--use-v2-block-manager]
                                             [--num-lookahead-slots NUM_LOOKAHEAD_SLOTS]
                                             [--seed SEED]
                                             [--swap-space SWAP_SPACE]
                                             [--gpu-memory-utilization GPU_MEMORY_UTILIZATION]
                                             [--num-gpu-blocks-override NUM_GPU_BLOCKS_OVERRIDE]
                                             [--max-num-batched-tokens MAX_NUM_BATCHED_TOKENS]
                                             [--max-num-seqs MAX_NUM_SEQS]
                                             [--max-logprobs MAX_LOGPROBS]
                                             [--disable-log-stats]
                                             [--quantization {aqlm,awq,fp8,gptq,squeezellm,gptq_marlin,marlin,None}]
                                             [--enforce-eager]
                                             [--max-context-len-to-capture MAX_CONTEXT_LEN_TO_CAPTURE]
                                             [--max-seq_len-to-capture MAX_SEQ_LEN_TO_CAPTURE]
                                             [--disable-custom-all-reduce]
                                             [--tokenizer-pool-size TOKENIZER_POOL_SIZE]
                                             [--tokenizer-pool-type TOKENIZER_POOL_TYPE]
                                             [--tokenizer-pool-extra-config TOKENIZER_POOL_EXTRA_CONFIG]
                                             [--enable-lora]
                                             [--max-loras MAX_LORAS]
                                             [--max-lora-rank MAX_LORA_RANK]
                                             [--lora-extra-vocab-size LORA_EXTRA_VOCAB_SIZE]
                                             [--lora-dtype {auto,float16,bfloat16,float32}]
                                             [--max-cpu-loras MAX_CPU_LORAS]
                                             [--fully-sharded-loras]
                                             [--device {auto,cuda,neuron,cpu}]
                                             [--image-input-type {pixel_values,image_features}]
                                             [--image-token-id IMAGE_TOKEN_ID]
                                             [--image-input-shape IMAGE_INPUT_SHAPE]
                                             [--image-feature-size IMAGE_FEATURE_SIZE]
                                             [--scheduler-delay-factor SCHEDULER_DELAY_FACTOR]
                                             [--enable-chunked-prefill]
                                             [--speculative-model SPECULATIVE_MODEL]
                                             [--num-speculative-tokens NUM_SPECULATIVE_TOKENS]
                                             [--speculative-max-model-len SPECULATIVE_MAX_MODEL_LEN]
                                             [--speculative-disable-by-batch-size SPECULATIVE_DISABLE_BY_BATCH_SIZE]
                                             [--ngram-prompt-lookup-max NGRAM_PROMPT_LOOKUP_MAX]
                                             [--ngram-prompt-lookup-min NGRAM_PROMPT_LOOKUP_MIN]
                                             [--model-loader-extra-config MODEL_LOADER_EXTRA_CONFIG]
                                             [--served-model-name SERVED_MODEL_NAME [SERVED_MODEL_NAME ...]]

Named Arguments#

--model

Name or path of the huggingface model to use.

Default: “facebook/opt-125m”
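
For example, a minimal server launch that only sets the model (here the documented default, facebook/opt-125m) might look like:

    python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m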

--tokenizer

Name or path of the huggingface tokenizer to use.

--skip-tokenizer-init

Skip initialization of tokenizer and detokenizer

--revision

The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.

--code-revision

The specific revision to use for the model code on Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.

--tokenizer-revision

The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.

--tokenizer-mode

Possible choices: auto, slow

The tokenizer mode.

  • “auto” will use the fast tokenizer if available.

  • “slow” will always use the slow tokenizer.

Default: “auto”

--trust-remote-code

Trust remote code from huggingface.

--download-dir

Directory to download and load the weights; defaults to the default Hugging Face cache directory.

--load-format

Possible choices: auto, pt, safetensors, npcache, dummy, tensorizer

The format of the model weights to load.

  • “auto” will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available.

  • “pt” will load the weights in the pytorch bin format.

  • “safetensors” will load the weights in the safetensors format.

  • “npcache” will load the weights in pytorch format and store a numpy cache to speed up the loading.

  • “dummy” will initialize the weights with random values, which is mainly for profiling.

  • “tensorizer” will load the weights using tensorizer from CoreWeave which assumes tensorizer_uri is set to the location of the serialized weights.

Default: “auto”
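
As an illustration, the following sketch forces safetensors loading and a custom download directory (the path is a placeholder):

    python -m vllm.entrypoints.openai.api_server \
        --model facebook/opt-125m \
        --load-format safetensors \
        --download-dir /path/to/weights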

--dtype

Possible choices: auto, half, float16, bfloat16, float, float32

Data type for model weights and activations.

  • “auto” will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.

  • “half” for FP16. Recommended for AWQ quantization.

  • “float16” is the same as “half”.

  • “bfloat16” for a balance between precision and range.

  • “float” is shorthand for FP32 precision.

  • “float32” for FP32 precision.

Default: “auto”
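
For example, to request BF16 weights and activations explicitly rather than relying on “auto”:

    python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype bfloat16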

--kv-cache-dtype

Possible choices: auto, fp8

Data type for KV cache storage. If “auto”, will use the model data type. FP8_E5M2 (without scaling) is only supported on CUDA versions greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is supported instead for common inference criteria.

Default: “auto”

--quantization-param-path

Path to the JSON file containing the KV cache scaling factors. This should generally be supplied when the KV cache dtype is FP8; otherwise, the KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on CUDA versions greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is supported instead for common inference criteria.
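
A sketch of enabling an FP8 KV cache together with a scaling-factor file (the JSON path is a placeholder; without it, scaling factors default to 1.0 as noted above):

    python -m vllm.entrypoints.openai.api_server \
        --model facebook/opt-125m \
        --kv-cache-dtype fp8 \
        --quantization-param-path /path/to/kv_cache_scales.json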

--max-model-len

Model context length. If unspecified, will be automatically derived from the model config.

--guided-decoding-backend

Possible choices: outlines, lm-format-enforcer

Which engine will be used for guided decoding (JSON schema / regex etc.) by default. Currently supports outlines-dev/outlines and noamgat/lm-format-enforcer. Can be overridden per request via the guided_decoding_backend parameter.

Default: “outlines”

--worker-use-ray

Use Ray for distributed serving; will be set automatically when using more than 1 GPU.

--pipeline-parallel-size, -pp

Number of pipeline stages.

Default: 1

--tensor-parallel-size, -tp

Number of tensor parallel replicas.

Default: 1
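
For example, a sketch of serving one model across 4 GPUs with tensor parallelism (Ray is enabled automatically when more than one GPU is used):

    python -m vllm.entrypoints.openai.api_server \
        --model facebook/opt-125m \
        --tensor-parallel-size 4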

--max-parallel-loading-workers

Load the model sequentially in multiple batches, to avoid RAM OOM when using tensor parallelism with large models.

--ray-workers-use-nsight

If specified, use nsight to profile Ray workers.

--block-size

Possible choices: 8, 16, 32

Token block size for contiguous chunks of tokens.

Default: 16

--enable-prefix-caching

Enables automatic prefix caching.

--use-v2-block-manager

Use BlockSpaceManagerV2.

--num-lookahead-slots

Experimental scheduling config necessary for speculative decoding. This will be replaced by speculative config in the future; it is present to enable correctness tests until then.

Default: 0

--seed

Random seed for operations.

Default: 0

--swap-space

CPU swap space size (GiB) per GPU.

Default: 4

--gpu-memory-utilization

The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory utilization. If unspecified, will use the default value of 0.9.

Default: 0.9
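
As an illustration of memory-related tuning, the following sketch caps GPU memory use at 80% and doubles the per-GPU CPU swap space (the values are illustrative, not recommendations):

    python -m vllm.entrypoints.openai.api_server \
        --model facebook/opt-125m \
        --gpu-memory-utilization 0.8 \
        --swap-space 8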

--num-gpu-blocks-override

If specified, ignore the GPU profiling result and use this number of GPU blocks. Used for testing preemption.

--max-num-batched-tokens

Maximum number of batched tokens per iteration.

--max-num-seqs

Maximum number of sequences per iteration.

Default: 256
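
For example, a sketch of tightening the batching limits (the numbers are illustrative only):

    python -m vllm.entrypoints.openai.api_server \
        --model facebook/opt-125m \
        --max-num-batched-tokens 8192 \
        --max-num-seqs 128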

--max-logprobs

Max number of log probs to return when logprobs is specified in SamplingParams.

Default: 5

--disable-log-stats

Disable logging statistics.

--quantization, -q

Possible choices: aqlm, awq, fp8, gptq, squeezellm, gptq_marlin, marlin, None

Method used to quantize the weights. If None, we first check the quantization_config attribute in the model config file. If that is None, we assume the model weights are not quantized and use dtype to determine the data type of the weights.
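
For example, a sketch of loading an AWQ-quantized checkpoint with the FP16 dtype recommended above (the model name is a placeholder for any AWQ checkpoint):

    python -m vllm.entrypoints.openai.api_server \
        --model <awq-quantized-model> \
        --quantization awq \
        --dtype half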

--enforce-eager

Always use eager-mode PyTorch. If False, will use a hybrid of eager mode and CUDA graphs for maximal performance and flexibility.

--max-context-len-to-capture

Maximum context length covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. (DEPRECATED. Use --max-seq_len-to-capture instead.)

--max-seq_len-to-capture

Maximum sequence length covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode.

Default: 8192
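
As a sketch, either disable CUDA graphs entirely or lower the captured sequence length to reduce capture overhead:

    python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --enforce-eager

    python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --max-seq_len-to-capture 4096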

--disable-custom-all-reduce

See ParallelConfig.

--tokenizer-pool-size

Size of tokenizer pool to use for asynchronous tokenization. If 0, will use synchronous tokenization.

Default: 0

--tokenizer-pool-type

Type of tokenizer pool to use for asynchronous tokenization. Ignored if tokenizer_pool_size is 0.

Default: “ray”

--tokenizer-pool-extra-config

Extra config for tokenizer pool. This should be a JSON string that will be parsed into a dictionary. Ignored if tokenizer_pool_size is 0.

--enable-lora

If True, enable handling of LoRA adapters.

--max-loras

Max number of LoRAs in a single batch.

Default: 1

--max-lora-rank

Max LoRA rank.

Default: 16

--lora-extra-vocab-size

Maximum size of extra vocabulary that can be present in a LoRA adapter (added to the base model vocabulary).

Default: 256

--lora-dtype

Possible choices: auto, float16, bfloat16, float32

Data type for LoRA. If auto, will default to base model dtype.

Default: “auto”

--max-cpu-loras

Maximum number of LoRAs to store in CPU memory. Must be greater than or equal to max_num_seqs. Defaults to max_num_seqs.

--fully-sharded-loras

By default, only half of the LoRA computation is sharded with tensor parallelism. Enabling this will use the fully sharded layers. At high sequence lengths, max rank, or tensor parallel size, this is likely to be faster.
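
Putting the LoRA flags together, a sketch of enabling adapter support (the values are illustrative):

    python -m vllm.entrypoints.openai.api_server \
        --model facebook/opt-125m \
        --enable-lora \
        --max-loras 4 \
        --max-lora-rank 32 \
        --lora-dtype bfloat16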

--device

Possible choices: auto, cuda, neuron, cpu

Device type for vLLM execution.

Default: “auto”

--image-input-type

Possible choices: pixel_values, image_features

The image input type passed into vLLM. Should be one of “pixel_values” or “image_features”.

--image-token-id

Input id for image token.

--image-input-shape

The biggest image input shape (worst for memory footprint) given an input type. Only used for vLLM’s profile_run.

--image-feature-size

The image feature size along the context dimension.
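
A sketch of the image flags for a LLaVA-1.5-style model; the token id, input shape, and feature size below are illustrative and must match the specific model’s configuration:

    python -m vllm.entrypoints.openai.api_server \
        --model llava-hf/llava-1.5-7b-hf \
        --image-input-type pixel_values \
        --image-token-id 32000 \
        --image-input-shape 1,3,336,336 \
        --image-feature-size 576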

--scheduler-delay-factor

Apply a delay (of delay factor multiplied by previous prompt latency) before scheduling the next prompt.

Default: 0.0

--enable-chunked-prefill

If set, the prefill requests can be chunked based on the max_num_batched_tokens.
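
For example, a sketch of enabling chunked prefill with an explicit per-iteration token budget:

    python -m vllm.entrypoints.openai.api_server \
        --model facebook/opt-125m \
        --enable-chunked-prefill \
        --max-num-batched-tokens 4096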

--speculative-model

The name of the draft model to be used in speculative decoding.

--num-speculative-tokens

The number of speculative tokens to sample from the draft model in speculative decoding.

--speculative-max-model-len

The maximum sequence length supported by the draft model. Sequences over this length will skip speculation.

--speculative-disable-by-batch-size

Disable speculative decoding for new incoming requests if the number of enqueued requests is larger than this value.

--ngram-prompt-lookup-max

Max size of window for ngram prompt lookup in speculative decoding.

--ngram-prompt-lookup-min

Min size of window for ngram prompt lookup in speculative decoding.
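
Putting the speculative-decoding flags together, a sketch that pairs a target model with a smaller draft model (the model names are placeholders; speculative decoding is typically used together with the v2 block manager):

    python -m vllm.entrypoints.openai.api_server \
        --model <target-model> \
        --speculative-model <small-draft-model> \
        --num-speculative-tokens 5 \
        --use-v2-block-manager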

--model-loader-extra-config

Extra config for model loader. This will be passed to the model loader corresponding to the chosen load_format. This should be a JSON string that will be parsed into a dictionary.

--served-model-name

The model name(s) used in the API. If multiple names are provided, the server will respond to any of them. The model name in the model field of a response will be the first name in this list. If not specified, the model name will be the same as the --model argument. Note that this name(s) will also be used in the model_name tag content of Prometheus metrics; if multiple names are provided, the metrics tag will take the first one.

Async Engine Arguments#

Below are the additional arguments related to the asynchronous engine:

usage: -m vllm.entrypoints.openai.api_server [-h] [--engine-use-ray]
                                             [--disable-log-requests]
                                             [--max-log-len MAX_LOG_LEN]

Named Arguments#

--engine-use-ray

Use Ray to start the LLM engine in a separate process from the server process.

--disable-log-requests

Disable logging requests.

--max-log-len

Max number of prompt characters or prompt token IDs printed in the log.

Default: Unlimited
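
For example, a sketch that truncates logged prompts to 100 characters (or token IDs):

    python -m vllm.entrypoints.openai.api_server \
        --model facebook/opt-125m \
        --max-log-len 100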