Other AI accelerators#

vLLM is a Python library that supports the following AI accelerators. Select your AI accelerator type to see vendor-specific instructions:

Tensor Processing Units (TPUs) are Google’s custom-developed application-specific integrated circuits (ASICs) used to accelerate machine learning workloads. TPUs are available in different versions, each with different hardware specifications. For more information about TPUs, see TPU System Architecture. For more information on the TPU versions supported with vLLM, see:

These TPU versions allow you to configure the physical arrangements of the TPU chips. This can improve throughput and networking performance. For more information see:

To use Cloud TPUs, you need TPU quota granted to your Google Cloud Platform project. TPU quota specifies how many TPUs you can use in a GCP project and is defined in terms of the TPU version, the number of TPUs you want to use, and the quota type. For more information, see TPU quota.

For TPU pricing information, see Cloud TPU pricing.

You may need additional persistent storage for your TPU VMs. For more information, see Storage options for Cloud TPU data.

This tab provides instructions on running vLLM with Intel Gaudi devices.

vLLM 0.3.3 and later supports model inference and serving on AWS Trainium/Inferentia with the Neuron SDK, including continuous batching. Paged Attention and Chunked Prefill are currently in development and will be available soon. The data types currently supported by the Neuron SDK are FP16 and BF16.

vLLM powered by OpenVINO supports all LLM models from the vLLM supported models list and can perform optimal model serving on all x86-64 CPUs with at least AVX2 support, as well as on both integrated and discrete Intel® GPUs (see the list of supported GPUs).

Requirements#

  • Google Cloud TPU VM

  • TPU versions: v6e, v5e, v5p, v4

  • Python: 3.10 or newer

Provision Cloud TPUs

You can provision Cloud TPUs using the Cloud TPU API or the queued resources API. This section shows how to create TPUs using the queued resource API. For more information about using the Cloud TPU API, see Create a Cloud TPU using the Create Node API. Queued resources enable you to request Cloud TPU resources in a queued manner. When you request queued resources, the request is added to a queue maintained by the Cloud TPU service. When the requested resource becomes available, it’s assigned to your Google Cloud project for your immediate exclusive use.

Note

In all of the following commands, replace the ALL CAPS parameter names with appropriate values. See the parameter descriptions table for more information.

Provision Cloud TPUs with GKE

For more information about using TPUs with GKE, see:

  • OS: Ubuntu 22.04 LTS

  • Python: 3.10

  • Intel Gaudi accelerator

  • Intel Gaudi software version 1.18.0

Please follow the instructions provided in the Gaudi Installation Guide to set up the execution environment. To achieve the best performance, please follow the methods outlined in the Optimizing Training Platform Guide.

  • OS: Linux

  • Python: 3.9 – 3.11

  • Accelerator: NeuronCore_v2 (in trn1/inf2 instances)

  • PyTorch 2.0.1/2.1.1

  • AWS Neuron SDK 2.16/2.17 (verified on Python 3.8)

  • OS: Linux

  • Instruction set architecture (ISA) requirement: at least AVX2.

Configure a new environment#

Provision a Cloud TPU with the queued resource API

Create a TPU v5e with 4 TPU chips:

gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
--node-id TPU_NAME \
--project PROJECT_ID \
--zone ZONE \
--accelerator-type ACCELERATOR_TYPE \
--runtime-version RUNTIME_VERSION \
--service-account SERVICE_ACCOUNT
Parameter descriptions#

  • QUEUED_RESOURCE_ID: The user-assigned ID of the queued resource request.

  • TPU_NAME: The user-assigned name of the TPU, which is created when the queued resource request is allocated.

  • PROJECT_ID: Your Google Cloud project ID.

  • ZONE: The GCP zone where you want to create your Cloud TPU. The value you use depends on the version of TPUs you are using. For more information, see TPU regions and zones (https://cloud.google.com/tpu/docs/regions-zones).

  • ACCELERATOR_TYPE: The TPU version and size you want to use. For example, v5litepod-4 specifies a v5e TPU with 4 cores. For more information, see TPU versions (https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions).

  • RUNTIME_VERSION: The TPU VM runtime version to use. For more information, see TPU VM images (https://cloud.google.com/tpu/docs/runtimes).

  • SERVICE_ACCOUNT: The email address for your service account. You can find it in the IAM Cloud Console under Service Accounts. For example: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com

Connect to your TPU using SSH:

gcloud compute tpus tpu-vm ssh TPU_NAME --zone ZONE

Environment verification

To verify that the Intel Gaudi software was correctly installed, run:

hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible
apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core, habanalabs-thunk and habanalabs-container-runtime are installed
pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed
pip list | grep neural # verify that neural_compressor is installed

Refer to Intel Gaudi Software Stack Verification for more details.

Run Docker Image

It is highly recommended to use the latest Docker image from Intel Gaudi vault. Refer to the Intel Gaudi documentation for more details.

Use the following commands to run a Docker image:

docker pull vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest

Launch Trn1/Inf2 instances

Here are the steps to launch Trn1/Inf2 instances and set up PyTorch Neuron (“torch-neuronx”) on Ubuntu 22.04 LTS.

  • Please follow the instructions at launch an Amazon EC2 Instance to launch an instance. When choosing the instance type at the EC2 console, please make sure to select the correct instance type.

  • To get more information about instances sizes and pricing see: Trn1 web page, Inf2 web page

  • Select the Ubuntu Server 22.04 LTS AMI

  • When launching a Trn1/Inf2, please adjust your primary EBS volume size to a minimum of 512GB.

  • After launching the instance, follow the instructions in Connect to your instance to connect to the instance

Install drivers and tools

Installing drivers and tools is not necessary if the Deep Learning AMI Neuron is used. If the drivers and tools are not installed on the operating system, follow the steps below:

# Configure Linux for Neuron repository updates
. /etc/os-release
sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null <<EOF
deb https://apt.repos.neuron.amazonaws.com ${VERSION_CODENAME} main
EOF
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -

# Update OS packages
sudo apt-get update -y

# Install OS headers
sudo apt-get install linux-headers-$(uname -r) -y

# Install git
sudo apt-get install git -y

# install Neuron Driver
sudo apt-get install aws-neuronx-dkms=2.* -y

# Install Neuron Runtime
sudo apt-get install aws-neuronx-collectives=2.* -y
sudo apt-get install aws-neuronx-runtime-lib=2.* -y

# Install Neuron Tools
sudo apt-get install aws-neuronx-tools=2.* -y

# Add PATH
export PATH=/opt/aws/neuron/bin:$PATH

You can create a new Python environment using conda:

# (Recommended) Create a new conda environment.
conda create -n myenv python=3.12 -y
conda activate myenv

Note

PyTorch has deprecated the conda release channel. If you use conda, please only use it to create the Python environment rather than to install packages.

Or you can create a new Python environment using uv, a very fast Python environment manager. Please follow the documentation to install uv. After installing uv, you can create a new Python environment using the following command:

# (Recommended) Create a new uv environment. Use `--seed` to install `pip` and `setuptools` in the environment.
uv venv myenv --python 3.12 --seed
source myenv/bin/activate

Set up using Python#

Pre-built wheels#

Currently, there are no pre-built TPU wheels.

Currently, there are no pre-built Intel Gaudi wheels.

Currently, there are no pre-built Neuron wheels.

Currently, there are no pre-built OpenVINO wheels.

Build wheel from source#

Install Miniconda:

wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh
source ~/.bashrc

Create and activate a Conda environment for vLLM:

conda create -n vllm python=3.10 -y
conda activate vllm

Clone the vLLM repository and go to the vLLM directory:

git clone https://github.com/vllm-project/vllm.git && cd vllm

Uninstall the existing torch and torch_xla packages:

pip uninstall torch torch-xla -y

Install build dependencies:

pip install -r requirements-tpu.txt
sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev

Run the setup script:

VLLM_TARGET_DEVICE="tpu" python setup.py develop
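
After the build completes, you can sanity-check the installation with a short offline-inference run. This is a minimal sketch using vLLM's standard Python API; the model name is only an example and any supported model can be substituted:

# Minimal smoke test of the freshly built package; model name is an example.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.8, max_tokens=32)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)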

To build and install vLLM from source, run:

git clone https://github.com/vllm-project/vllm.git
cd vllm
python setup.py develop

Currently, the latest features and performance optimizations are developed in Gaudi’s vLLM-fork, and we periodically upstream them to the vLLM main repo. To install the latest HabanaAI/vLLM-fork, run the following:

git clone https://github.com/HabanaAI/vllm-fork.git
cd vllm-fork
git checkout habana_main
python setup.py develop

Note

The currently supported version of PyTorch for Neuron installs triton version 2.1.0. This is incompatible with vllm >= 0.5.3. You may see an error like cannot import name 'default_dump_dir.... To work around this, run pip install --upgrade triton==3.0.0 after installing the vLLM wheel.

The following instructions are applicable to Neuron SDK 2.16 and beyond.

Install transformers-neuronx and its dependencies

transformers-neuronx will be the backend to support inference on Trn1/Inf2 instances. Follow the steps below to install the transformers-neuronx package and its dependencies.

# Install Python venv
sudo apt-get install -y python3.10-venv g++

# Create Python venv
python3.10 -m venv aws_neuron_venv_pytorch

# Activate Python venv
source aws_neuron_venv_pytorch/bin/activate

# Install Jupyter notebook kernel
pip install ipykernel
python3.10 -m ipykernel install --user --name aws_neuron_venv_pytorch --display-name "Python (torch-neuronx)"
pip install jupyter notebook
pip install environment_kernels

# Set pip repository pointing to the Neuron repository
python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com

# Install wget, awscli
python -m pip install wget
python -m pip install awscli

# Update Neuron Compiler and Framework
python -m pip install --upgrade neuronx-cc==2.* --pre torch-neuronx==2.1.* torchvision transformers-neuronx

Install vLLM from source

Once the neuronx-cc and transformers-neuronx packages are installed, you can install vLLM as follows:

git clone https://github.com/vllm-project/vllm.git
cd vllm
pip install -U -r requirements-neuron.txt
VLLM_TARGET_DEVICE="neuron" pip install .

If the Neuron packages are detected correctly during the installation process, vllm-0.3.0+neuron212 will be installed.
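
To confirm which build was installed, you can check the package version from Python; the exact version string depends on your environment:

import vllm

# Expect a version such as 0.3.0+neuron212 if the Neuron build was detected.
print(vllm.__version__)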

First, install Python. For example, on Ubuntu 22.04, you can run:

sudo apt-get update  -y
sudo apt-get install python3

Second, install the prerequisites for the vLLM OpenVINO backend installation:

pip install --upgrade pip
pip install -r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/cpu

Finally, install vLLM with OpenVINO backend:

PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE=openvino python -m pip install -v .

Tip

To use vLLM OpenVINO backend with a GPU device, ensure your system is properly set up. Follow the instructions provided here: https://docs.openvino.ai/2024/get-started/configurations/configurations-intel-gpu.html.

Set up using Docker#

Pre-built images#

See Use vLLM’s Official Docker Image for instructions on using the official Docker image, making sure to substitute the image name vllm/vllm-openai with vllm/vllm-tpu.

Currently, there are no pre-built Intel Gaudi images.

Currently, there are no pre-built Neuron images.

Currently, there are no pre-built OpenVINO images.

Build image from source#

You can use Dockerfile.tpu to build a Docker image with TPU support.

docker build -f Dockerfile.tpu -t vllm-tpu .

Run the Docker image with the following command:

# Make sure to add `--privileged --net host --shm-size=16G`.
docker run --privileged --net host --shm-size=16G -it vllm-tpu

Note

Since TPU relies on XLA which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each shape. The compilation time may take 20~30 minutes in the first run. However, the compilation time reduces to ~5 minutes afterwards because the XLA graphs are cached in the disk (in VLLM_XLA_CACHE_PATH or ~/.cache/vllm/xla_cache by default).

Tip

If you encounter the following error:

from torch._C import *  # noqa: F403
ImportError: libopenblas.so.0: cannot open shared object file: No such
file or directory

Install OpenBLAS with the following command:

$ sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev

Use Dockerfile.hpu to build and run a Docker image with Intel Gaudi support:

docker build -f Dockerfile.hpu -t vllm-hpu-env .
docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --rm vllm-hpu-env

Tip

If you observe the following error: docker: Error response from daemon: Unknown runtime specified habana., please refer to the “Install Using Containers” section of Intel Gaudi Software Stack and Driver Installation. Make sure you have the habana-container-runtime package installed and that the habana container runtime is registered.

See Building vLLM’s Docker Image from Source for instructions on building the Docker image.

Make sure to use Dockerfile.neuron in place of the default Dockerfile.

docker build -f Dockerfile.openvino -t vllm-openvino-env .
docker run -it --rm vllm-openvino-env

Extra information#

There is no extra information for this device.

Supported features

  • Offline inference

  • Online serving via OpenAI-Compatible Server

  • HPU autodetection - no need to manually select device within vLLM

  • Paged KV cache with algorithms enabled for Intel Gaudi accelerators

  • Custom Intel Gaudi implementations of Paged Attention, KV cache ops, prefill attention, Root Mean Square Layer Normalization, Rotary Positional Encoding

  • Tensor parallelism support for multi-card inference

  • Inference with HPU Graphs for accelerating low-batch latency and throughput

  • Attention with Linear Biases (ALiBi)

Unsupported features

  • Beam search

  • LoRA adapters

  • Quantization

  • Prefill chunking (mixed-batch inferencing)

Supported configurations

The following configurations have been validated to function with Gaudi2 devices. Configurations that are not listed may or may not work.

Performance tuning

Execution modes

Currently, vLLM for HPU supports four execution modes, depending on the selected HPU PyTorch Bridge backend (via the PT_HPU_LAZY_MODE environment variable) and the --enforce-eager flag.

vLLM execution modes#

  • PT_HPU_LAZY_MODE=0, enforce_eager=0: torch.compile

  • PT_HPU_LAZY_MODE=0, enforce_eager=1: PyTorch eager mode

  • PT_HPU_LAZY_MODE=1, enforce_eager=0: HPU Graphs

  • PT_HPU_LAZY_MODE=1, enforce_eager=1: PyTorch lazy mode

Warning

In 1.18.0, all modes utilizing PT_HPU_LAZY_MODE=0 are highly experimental and should only be used for validating functional correctness. Their performance will be improved in the next releases. For the best performance in 1.18.0, use HPU Graphs or PyTorch lazy mode.

Bucketing mechanism

Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. The Intel Gaudi Graph Compiler is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution. In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occurring during server runtime. Currently, this is achieved by “bucketing” the model’s forward pass across two dimensions: batch_size and sequence_length.

Note

Bucketing allows us to reduce the number of required graphs significantly, but it does not handle any graph compilation and device code generation - this is done in warmup and HPUGraph capture phase.

Bucketing ranges are determined with 3 parameters - min, step and max. They can be set separately for prompt and decode phase, and for batch size and sequence length dimension. These parameters can be observed in logs during vLLM startup:

INFO 08-01 21:37:59 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024]
INFO 08-01 21:37:59 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)]
INFO 08-01 21:37:59 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048]
INFO 08-01 21:37:59 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)]

min determines the lowest value of the bucket. step determines the interval between buckets, and max determines the upper bound of the bucket. Furthermore, the interval between min and step has special handling: min is multiplied by consecutive powers of two until step is reached. We call this the ramp-up phase; it is used for handling lower batch sizes with minimum wastage, while allowing larger padding on larger batch sizes.

Example (with ramp-up)

min = 2, step = 32, max = 64
=> ramp_up = (2, 4, 8, 16)
=> stable = (32, 64)
=> buckets = ramp_up + stable => (2, 4, 8, 16, 32, 64)

Example (without ramp-up)

min = 128, step = 128, max = 512
=> ramp_up = ()
=> stable = (128, 256, 384, 512)
=> buckets = ramp_up + stable => (128, 256, 384, 512)
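
The ramp-up rule can be expressed compactly in Python. The helper below is only an illustration of the described min/step/max semantics, not vLLM's actual bucketing code; it reproduces the two examples above:

def warmup_buckets(bucket_min, step, bucket_max):
    # Ramp-up phase: multiply min by consecutive powers of two until step is reached.
    ramp_up = []
    value = bucket_min
    while value < step:
        ramp_up.append(value)
        value *= 2
    # Stable phase: multiples of step, up to and including max.
    stable = list(range(step, bucket_max + 1, step))
    return ramp_up + stable

print(warmup_buckets(2, 32, 64))      # [2, 4, 8, 16, 32, 64]
print(warmup_buckets(128, 128, 512))  # [128, 256, 384, 512]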

In the logged scenario, 24 buckets were generated for prompt (prefill) runs, and 48 buckets for decode runs. Each bucket corresponds to a separate optimized device binary for a given model with specified tensor shapes. Whenever a batch of requests is processed, it is padded across batch and sequence length dimension to the smallest possible bucket.

Warning

If a request exceeds the maximum bucket size in any dimension, it will be processed without padding, and its processing may require a graph compilation, potentially significantly increasing end-to-end latency. The boundaries of the buckets are user-configurable via environment variables, and upper bucket boundaries can be increased to avoid such a scenario.

As an example, if a request of 3 sequences, with a maximum sequence length of 412, comes in to an idle vLLM server, it will be padded and executed as a (4, 512) prefill bucket, as batch_size (number of sequences) will be padded to 4 (the closest batch_size dimension higher than 3), and the maximum sequence length will be padded to 512 (the closest sequence length dimension higher than 412). After the prefill stage, it will be executed as a (4, 512) decode bucket and will continue as that bucket until either the batch dimension changes (due to a request finishing), in which case it will become a (2, 512) bucket, or the context length increases above 512 tokens, in which case it will become a (4, 640) bucket.
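
Picking the smallest bucket that fits a request can be sketched as follows; this is illustrative only and assumes the bucket dimensions from the example above:

def find_bucket(dimension_buckets, value):
    # Smallest configured bucket value that is >= the requested value.
    return min(b for b in dimension_buckets if b >= value)

batch_size_buckets = [1, 2, 4]
seq_len_buckets = [128, 256, 384, 512, 640, 768, 896, 1024]

# A batch of 3 sequences with a maximum sequence length of 412 ...
print(find_bucket(batch_size_buckets, 3), find_bucket(seq_len_buckets, 412))  # 4 512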

Note

Bucketing is transparent to a client – padding in sequence length dimension is never returned to the client, and padding in batch dimension does not create new requests.

Warmup

Warmup is an optional, but highly recommended step occurring before vLLM server starts listening. It executes a forward pass for each bucket with dummy data. The goal is to pre-compile all graphs and not incur any graph compilation overheads within bucket boundaries during server runtime. Each warmup step is logged during vLLM startup:

INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB
INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][2/24] batch_size:4 seq_len:896 free_mem:55.43 GiB
INFO 08-01 22:26:48 hpu_model_runner.py:1066] [Warmup][Prompt][3/24] batch_size:4 seq_len:768 free_mem:55.43 GiB
...
INFO 08-01 22:26:59 hpu_model_runner.py:1066] [Warmup][Prompt][24/24] batch_size:1 seq_len:128 free_mem:55.43 GiB
INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][1/48] batch_size:4 seq_len:2048 free_mem:55.43 GiB
INFO 08-01 22:27:00 hpu_model_runner.py:1066] [Warmup][Decode][2/48] batch_size:4 seq_len:1920 free_mem:55.43 GiB
INFO 08-01 22:27:01 hpu_model_runner.py:1066] [Warmup][Decode][3/48] batch_size:4 seq_len:1792 free_mem:55.43 GiB
...
INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][47/48] batch_size:2 seq_len:128 free_mem:55.43 GiB
INFO 08-01 22:27:16 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB

This example uses the same buckets as in the Bucketing Mechanism section. Each output line corresponds to the execution of a single bucket. When a bucket is executed for the first time, its graph is compiled and can be reused later, skipping further graph compilations.

Tip

Compiling all the buckets might take some time and can be turned off with the VLLM_SKIP_WARMUP=true environment variable. Keep in mind that if you do so, you may face graph compilations when executing a given bucket for the first time. Disabling warmup is fine for development, but it is highly recommended to enable it in deployment.

HPU Graph capture

HPU Graphs are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management.

When HPU Graphs are used, they share the common memory pool (“usable memory”) with the KV cache, determined by the gpu_memory_utilization flag (0.9 by default). Before the KV cache is allocated, model weights are loaded onto the device, and a forward pass of the model is executed on dummy data to estimate memory usage. Only after that is the gpu_memory_utilization flag applied: at its default value, it marks 90% of the free device memory at that point as usable. Next, the KV cache is allocated, the model is warmed up, and HPU Graphs are captured. The environment variable VLLM_GRAPH_RESERVED_MEM defines the ratio of usable memory reserved for HPU Graph capture. With its default value (VLLM_GRAPH_RESERVED_MEM=0.1), 10% of usable memory is reserved for graph capture (later referred to as “usable graph memory”), and the remaining 90% is used for the KV cache. The environment variable VLLM_GRAPH_PROMPT_RATIO determines the share of usable graph memory reserved for prefill graphs versus decode graphs. By default (VLLM_GRAPH_PROMPT_RATIO=0.3), 30% of usable graph memory is reserved for prefill graphs and 70% for decode graphs. A lower value corresponds to less usable graph memory reserved for the prefill stage, e.g. VLLM_GRAPH_PROMPT_RATIO=0.2 reserves 20% of usable graph memory for prefill graphs and 80% for decode graphs.
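
As a back-of-the-envelope illustration of how these flags interact (a simplified sketch, not vLLM's actual memory accounting), the split reported in the example server log later in this section can be reproduced as follows:

# Values taken from the example log below (gpu_memory_utilization=0.5,
# VLLM_GRAPH_RESERVED_MEM=0.4); the prompt/decode split follows the stated
# semantics of VLLM_GRAPH_PROMPT_RATIO.
free_device_mem = 79.16                    # GiB free after the profiling run
gpu_memory_utilization = 0.5
graph_reserved_mem = 0.4                   # VLLM_GRAPH_RESERVED_MEM
graph_prompt_ratio = 0.3                   # VLLM_GRAPH_PROMPT_RATIO

usable_mem = free_device_mem * gpu_memory_utilization  # ~39.58 GiB usable
graph_mem = usable_mem * graph_reserved_mem             # ~15.83 GiB for HPU Graphs
kv_cache_mem = usable_mem - graph_mem                   # ~23.75 GiB for KV cache
prompt_graph_mem = graph_mem * graph_prompt_ratio       # share for prefill graphs
decode_graph_mem = graph_mem - prompt_graph_mem         # share for decode graphs

print(f"{usable_mem:.2f} {graph_mem:.2f} {kv_cache_mem:.2f}")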

Note

gpu_memory_utilization does not correspond to the absolute memory usage across HPU. It specifies the memory margin after loading the model and performing a profile run. If device has 100 GiB of total memory, and 50 GiB of free memory after loading model weights and executing profiling run, gpu_memory_utilization at its default value will mark 90% of 50 GiB as usable, leaving 5 GiB of margin, regardless of total device memory.

Users can also configure the strategy for capturing HPU Graphs separately for the prompt and decode stages. The strategy affects the order in which graphs are captured. Two strategies are implemented:

  • max_bs: the graph capture queue is sorted in descending order by batch size. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. (64, 128), (64, 256), (32, 128), (32, 256), (1, 128), (1, 256)). This is the default strategy for decode.

  • min_tokens: the graph capture queue is sorted in ascending order by the number of tokens each graph processes (batch_size*sequence_length). This is the default strategy for prompt.

When there is a large number of requests pending, the vLLM scheduler will attempt to fill the maximum decode batch size as soon as possible. When a request finishes, the decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size back to its previous state. This means that in a full-load scenario, the decode batch size is often at its maximum, which makes large-batch HPU Graphs crucial to capture, as reflected by the max_bs strategy. On the other hand, prefill iterations will most frequently be executed with very low batch sizes (1-4), which is reflected in the min_tokens strategy.

Note

VLLM_GRAPH_PROMPT_RATIO does not set a hard limit on the memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up the entirety of the usable prefill graph memory (usable graph memory * VLLM_GRAPH_PROMPT_RATIO) for capturing prefill HPU Graphs; next, it will attempt to do the same for decode graphs and the usable decode graph memory pool. If one stage is fully captured and there is unused memory left within the usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding the reserved memory pool. The behavior of this mechanism can be observed in the example below.

Each described step is logged by vLLM server, as follows (negative values correspond to memory being released):

INFO 08-02 17:37:44 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024]
INFO 08-02 17:37:44 hpu_model_runner.py:499] Generated 24 prompt buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024)]
INFO 08-02 17:37:44 hpu_model_runner.py:504] Decode bucket config (min, step, max_warmup) bs:[1, 128, 4], seq:[128, 128, 2048]
INFO 08-02 17:37:44 hpu_model_runner.py:509] Generated 48 decode buckets: [(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)]
INFO 08-02 17:37:52 hpu_model_runner.py:430] Pre-loading model weights on hpu:0 took 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used)
INFO 08-02 17:37:52 hpu_model_runner.py:438] Wrapping in HPU Graph took 0 B of device memory (14.97 GiB/94.62 GiB used) and -252 KiB of host memory (475.2 GiB/1007 GiB used)
INFO 08-02 17:37:52 hpu_model_runner.py:442] Loading model weights took in total 14.97 GiB of device memory (14.97 GiB/94.62 GiB used) and 2.95 GiB of host memory (475.2 GiB/1007 GiB used)
INFO 08-02 17:37:54 hpu_worker.py:134] Model profiling run took 504 MiB of device memory (15.46 GiB/94.62 GiB used) and 180.9 MiB of host memory (475.4 GiB/1007 GiB used)
INFO 08-02 17:37:54 hpu_worker.py:158] Free device memory: 79.16 GiB, 39.58 GiB usable (gpu_memory_utilization=0.5), 15.83 GiB reserved for HPUGraphs (VLLM_GRAPH_RESERVED_MEM=0.4), 23.75 GiB reserved for KV cache
INFO 08-02 17:37:54 hpu_executor.py:85] # HPU blocks: 1519, # CPU blocks: 0
INFO 08-02 17:37:54 hpu_worker.py:190] Initializing cache engine took 23.73 GiB of device memory (39.2 GiB/94.62 GiB used) and -1.238 MiB of host memory (475.4 GiB/1007 GiB used)
INFO 08-02 17:37:54 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:55.43 GiB
...
INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Decode][48/48] batch_size:1 seq_len:128 free_mem:55.43 GiB
INFO 08-02 17:38:22 hpu_model_runner.py:1159] Using 15.85 GiB/55.43 GiB of free device memory for HPUGraphs, 7.923 GiB for prompt and 7.923 GiB for decode (VLLM_GRAPH_PROMPT_RATIO=0.3)
INFO 08-02 17:38:22 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][1/24] batch_size:1 seq_len:128 free_mem:55.43 GiB
...
INFO 08-02 17:38:26 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][11/24] batch_size:1 seq_len:896 free_mem:48.77 GiB
INFO 08-02 17:38:27 hpu_model_runner.py:1066] [Warmup][Graph/Decode][1/48] batch_size:4 seq_len:128 free_mem:47.51 GiB
...
INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Decode][48/48] batch_size:1 seq_len:2048 free_mem:47.35 GiB
INFO 08-02 17:38:41 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][12/24] batch_size:4 seq_len:256 free_mem:47.35 GiB
INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][13/24] batch_size:2 seq_len:512 free_mem:45.91 GiB
INFO 08-02 17:38:42 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][14/24] batch_size:1 seq_len:1024 free_mem:44.48 GiB
INFO 08-02 17:38:43 hpu_model_runner.py:1066] [Warmup][Graph/Prompt][15/24] batch_size:2 seq_len:640 free_mem:43.03 GiB
INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Prompt captured:15 (62.5%) used_mem:14.03 GiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (4, 128), (4, 256)]
INFO 08-02 17:38:43 hpu_model_runner.py:1128] Graph/Decode captured:48 (100.0%) used_mem:161.9 MiB buckets:[(1, 128), (1, 256), (1, 384), (1, 512), (1, 640), (1, 768), (1, 896), (1, 1024), (1, 1152), (1, 1280), (1, 1408), (1, 1536), (1, 1664), (1, 1792), (1, 1920), (1, 2048), (2, 128), (2, 256), (2, 384), (2, 512), (2, 640), (2, 768), (2, 896), (2, 1024), (2, 1152), (2, 1280), (2, 1408), (2, 1536), (2, 1664), (2, 1792), (2, 1920), (2, 2048), (4, 128), (4, 256), (4, 384), (4, 512), (4, 640), (4, 768), (4, 896), (4, 1024), (4, 1152), (4, 1280), (4, 1408), (4, 1536), (4, 1664), (4, 1792), (4, 1920), (4, 2048)]
INFO 08-02 17:38:43 hpu_model_runner.py:1206] Warmup finished in 49 secs, allocated 14.19 GiB of device memory
INFO 08-02 17:38:43 hpu_executor.py:91] init_cache_engine took 37.92 GiB of device memory (53.39 GiB/94.62 GiB used) and 57.86 MiB of host memory (475.4 GiB/1007 GiB used)
  • We recommend running inference on Gaudi 2 with block_size of 128 for BF16 data type. Using default values (16, 32) might lead to sub-optimal performance due to Matrix Multiplication Engine under-utilization (see Gaudi Architecture).

  • For max throughput on Llama 7B, we recommend running with batch size of 128 or 256 and max context length of 2048 with HPU Graphs enabled. If you encounter out-of-memory issues, see troubleshooting section.

Environment variables

Diagnostic and profiling knobs:

  • VLLM_PROFILER_ENABLED: if true, high level profiler will be enabled. Resulting JSON traces can be viewed in perfetto.habana.ai. Disabled by default.

  • VLLM_HPU_LOG_STEP_GRAPH_COMPILATION: if true, will log graph compilations per each vLLM engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS=1. Disabled by default.

  • VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL: if true, will log graph compilations per each vLLM engine step, always, even if there were none. Disabled by default.

  • VLLM_HPU_LOG_STEP_CPU_FALLBACKS: if true, will log cpu fallbacks per each vLLM engine step, only when there was any. Disabled by default.

  • VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL: if true, will log cpu fallbacks per each vLLM engine step, always, even if there were none. Disabled by default.

Performance tuning knobs:

  • VLLM_SKIP_WARMUP: if true, warmup will be skipped, false by default

  • VLLM_GRAPH_RESERVED_MEM: percentage of memory dedicated for HPUGraph capture, 0.1 by default

  • VLLM_GRAPH_PROMPT_RATIO: percentage of reserved graph memory dedicated for prompt graphs, 0.3 by default

  • VLLM_GRAPH_PROMPT_STRATEGY: strategy determining order of prompt graph capture, min_tokens or max_bs, min_tokens by default

  • VLLM_GRAPH_DECODE_STRATEGY: strategy determining order of decode graph capture, min_tokens or max_bs, max_bs by default

  • VLLM_{phase}_{dim}_BUCKET_{param} - collection of 12 environment variables configuring ranges of bucketing mechanism

    • {phase} is either PROMPT or DECODE

    • {dim} is either BS, SEQ or BLOCK

    • {param} is either MIN, STEP or MAX

    • Default values:

      • Prompt:

        • batch size min (VLLM_PROMPT_BS_BUCKET_MIN): 1

        • batch size step (VLLM_PROMPT_BS_BUCKET_STEP): min(max_num_seqs, 32)

        • batch size max (VLLM_PROMPT_BS_BUCKET_MAX): min(max_num_seqs, 64)

        • sequence length min (VLLM_PROMPT_SEQ_BUCKET_MIN): block_size

        • sequence length step (VLLM_PROMPT_SEQ_BUCKET_STEP): block_size

        • sequence length max (VLLM_PROMPT_SEQ_BUCKET_MAX): max_model_len

      • Decode:

        • batch size min (VLLM_DECODE_BS_BUCKET_MIN): 1

        • batch size step (VLLM_DECODE_BS_BUCKET_STEP): min(max_num_seqs, 32)

        • batch size max (VLLM_DECODE_BS_BUCKET_MAX): max_num_seqs

        • sequence length min (VLLM_DECODE_BLOCK_BUCKET_MIN): block_size

        • sequence length step (VLLM_DECODE_BLOCK_BUCKET_STEP): block_size

        • sequence length max (VLLM_DECODE_BLOCK_BUCKET_MAX): max(128, (max_num_seqs*max_model_len)/block_size)
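
For reference, the defaults listed above can be written out as a small computation; the sketch below merely restates the formulas, with example engine settings (max_num_seqs, block_size, max_model_len) as illustrative inputs:

# Illustrative restatement of the default bucketing ranges listed above.
max_num_seqs, block_size, max_model_len = 256, 128, 2048  # example engine settings

prompt_defaults = {
    "bs (min, step, max)": (1, min(max_num_seqs, 32), min(max_num_seqs, 64)),
    "seq (min, step, max)": (block_size, block_size, max_model_len),
}
decode_defaults = {
    "bs (min, step, max)": (1, min(max_num_seqs, 32), max_num_seqs),
    "block (min, step, max)": (block_size, block_size,
                               max(128, (max_num_seqs * max_model_len) // block_size)),
}
print(prompt_defaults)
print(decode_defaults)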

Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution:

  • PT_HPU_LAZY_MODE: if 0, PyTorch Eager backend for Gaudi will be used, if 1 PyTorch Lazy backend for Gaudi will be used, 1 is default

  • PT_HPU_ENABLE_LAZY_COLLECTIVES: required to be true for tensor parallel inference with HPU Graphs

Troubleshooting: tweaking HPU graphs

If you experience device out-of-memory issues or want to attempt inference at higher batch sizes, try tweaking HPU Graphs by following the suggestions below:

  • Tweak the gpu_memory_utilization knob. Decreasing it reduces the allocation of KV cache, leaving some headroom for capturing graphs with larger batch sizes. By default, gpu_memory_utilization is set to 0.9, which attempts to allocate ~90% of the HBM left after the short profiling run for KV cache. Note that decreasing it reduces the number of KV cache blocks available, and therefore the effective maximum number of tokens you can handle at a given time.

  • If this method is not sufficient, you can disable HPU Graphs completely. With HPU Graphs disabled, you are trading latency and throughput at lower batch sizes for potentially higher throughput at higher batch sizes. You can do that by adding the --enforce-eager flag to the server (for online serving), or by passing enforce_eager=True to the LLM constructor (for offline inference), as shown in the sketch below.
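
For offline inference, both knobs can be passed directly to the LLM constructor. This is a minimal sketch; the model name and values are placeholders:

from vllm import LLM

# Lower gpu_memory_utilization to leave headroom for HPU Graph capture,
# or disable HPU Graphs entirely with enforce_eager=True.
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",  # placeholder model
    gpu_memory_utilization=0.8,             # default is 0.9
    enforce_eager=True,                     # disables HPU Graphs
)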

There is no extra information for this device.

Supported features

OpenVINO vLLM backend supports the following advanced vLLM features:

  • Prefix caching (--enable-prefix-caching)

  • Chunked prefill (--enable-chunked-prefill)
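
For offline inference, the same features can be toggled via the LLM constructor. A hedged sketch, assuming the constructor accepts the engine arguments corresponding to these flags; the model name is a placeholder:

from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",  # placeholder model
    enable_prefix_caching=True,             # --enable-prefix-caching
    enable_chunked_prefill=True,            # --enable-chunked-prefill
)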

Performance tips

vLLM OpenVINO backend environment variables

  • VLLM_OPENVINO_DEVICE to specify which device to utilize for inference. If there are multiple GPUs in the system, additional indexes can be used to choose the proper one (e.g., VLLM_OPENVINO_DEVICE=GPU.1). If the value is not specified, the CPU device is used by default.

  • VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON to enable U8 weight compression during the model loading stage. By default, compression is turned off. You can also export a model with different compression techniques using optimum-cli and pass the exported folder as <model_id>.

CPU performance tips

CPU uses the following environment variables to control behavior:

  • VLLM_OPENVINO_KVCACHE_SPACE to specify the KV cache size (e.g., VLLM_OPENVINO_KVCACHE_SPACE=40 means 40 GB of space for the KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and the user’s memory management pattern.

  • VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8 to control KV cache precision. By default, FP16 / BF16 is used depending on platform.

To achieve better TPOT / TTFT latency, you can use vLLM’s chunked prefill feature (--enable-chunked-prefill). Based on experiments, the recommended batch size is 256 (--max-num-batched-tokens).

OpenVINO best known configuration for CPU is:

$ VLLM_OPENVINO_KVCACHE_SPACE=100 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8 VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \
    python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --enable-chunked-prefill --max-num-batched-tokens 256

GPU performance tips

GPU device implements the logic for automatic detection of available GPU memory and, by default, tries to reserve as much memory as possible for the KV cache (taking into account gpu_memory_utilization option). However, this behavior can be overridden by explicitly specifying the desired amount of memory for the KV cache using VLLM_OPENVINO_KVCACHE_SPACE environment variable (e.g, VLLM_OPENVINO_KVCACHE_SPACE=8 means 8 GB space for KV cache).

Currently, the best performance using GPU can be achieved with the default vLLM execution parameters for models with quantized weights (8 and 4-bit integer data types are supported) and preemption-mode=swap.

OpenVINO best known configuration for GPU is:

$ VLLM_OPENVINO_DEVICE=GPU VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \
    python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json

Limitations

  • LoRA serving is not supported.

  • Only LLM models are currently supported. LLaVa and encoder-decoder models are not currently enabled in vLLM OpenVINO integration.

  • Tensor and pipeline parallelism are not currently enabled in vLLM integration.