
Source: examples/offline_inference/profiling_tpu.

vLLM TPU Profiling

This script is used to profile the TPU performance of vLLM for specific prefill or decode token shapes.

Note: an actual running server processes a mix of prefills and decodes of many different shapes.

We assume you are on a TPU already (this was tested on TPU v6e) and have installed vLLM according to the installation guide.

In all examples below, we run several warmup iterations before profiling, so using --enforce-eager is okay.
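The warmup-then-profile pattern can be sketched generically: the first calls absorb one-time costs (on TPU, for example, XLA compilation), so the later, measured calls reflect steady-state behavior. This is a minimal, hypothetical sketch; `timed`, `warmup_then_measure`, and the toy `workload` are illustrative names, not part of vLLM. The script itself does the same thing with its `run_to_completion` helper and the `--num-iters-warmup` / `--num-iters` flags, whose defaults (5 and 1) are reused here.

```python
import time
from statistics import mean
from typing import Callable


def timed(fn: Callable[[], object]) -> float:
    """Return the wall-clock latency of one call to fn, in seconds."""
    start = time.perf_counter()
    fn()
    return time.perf_counter() - start


def warmup_then_measure(fn, num_iters_warmup=5, num_iters=1):
    """Run warmup calls first, then the calls we actually care about."""
    warmup = [timed(fn) for _ in range(num_iters_warmup)]
    measured = [timed(fn) for _ in range(num_iters)]
    return mean(warmup), mean(measured)


# Toy workload: the first call simulates a one-time compilation cost.
calls = []


def workload():
    if not calls:
        time.sleep(0.05)
    calls.append(1)


warm, measured = warmup_then_measure(workload)
print(f"warmup avg: {warm:.4f}s, measured avg: {measured:.4f}s")
```

Because the slow first call lands in the warmup phase, the measured latency is not inflated by it, which is the same reason the profiles below are captured only after warmup.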

Profile Examples

Generate Prefill Trace

This example runs Qwen/Qwen2.5-7B-Instruct with a single request of 1024 input tokens. It is set up in an attempt to profile just the prefill time and operations.

export XLA_HLO_DEBUG=1
export MODEL=Qwen/Qwen2.5-7B-Instruct
export VLLM_TPU_PROFILE_DURATION_MS=3000
export VLLM_TPU_PROFILE_DELAY_MS=0

python3 profiling.py \
    --model $MODEL \
    --input-len 1024 --output-len 1 \
    --batch-size 1 --enforce-eager \
    --max-model-len 2048 \
    --tensor-parallel-size 1 \
    --profile-result-dir profiles
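The two VLLM_TPU_PROFILE_* variables define the capture window that profiling.py passes to `xp.trace_detached`. A minimal sketch of that arithmetic, using the same variable names and defaults as the script; `trace_window_ms` is a hypothetical helper, and the returned times are only approximately relative to the moment tracing is requested.

```python
import os
from typing import Mapping, Tuple


def trace_window_ms(env: Mapping[str, str] = os.environ) -> Tuple[int, int]:
    """Return (start, end) of the capture window in milliseconds."""
    # Same variable names and defaults as profiling.py.
    duration_ms = int(env.get("VLLM_TPU_PROFILE_DURATION_MS", 3000))
    delay_ms = int(env.get("VLLM_TPU_PROFILE_DELAY_MS", 0))
    return delay_ms, delay_ms + duration_ms


# Prefill example: capture starts right away and lasts 3 seconds.
print(trace_window_ms({"VLLM_TPU_PROFILE_DURATION_MS": "3000",
                       "VLLM_TPU_PROFILE_DELAY_MS": "0"}))  # (0, 3000)
# Decode example: skip the first second, then capture 2 seconds.
print(trace_window_ms({"VLLM_TPU_PROFILE_DURATION_MS": "2000",
                       "VLLM_TPU_PROFILE_DELAY_MS": "1000"}))  # (1000, 3000)
```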

Generate Decode Trace

This example runs Llama 3.1 70B with a batch of 32 requests, each with 1 input token and 128 output tokens. It is set up in an attempt to profile just the 32 decodes running in parallel: the prefill is kept extremely small (1 token), and VLLM_TPU_PROFILE_DELAY_MS=1000 skips the first second of inference, which should cover the prefill.

export XLA_HLO_DEBUG=1
export MODEL=meta-llama/Llama-3.1-70B-Instruct
export VLLM_TPU_PROFILE_DURATION_MS=2000
export VLLM_TPU_PROFILE_DELAY_MS=1000

rm -rf ~/.cache/vllm/xla_cache
python3 profiling.py \
    --model $MODEL \
    --input-len 1 \
    --output-len 128 \
    --batch-size 32 \
    --enforce-eager \
    --profile-result-dir profiles \
    --max-model-len 2048 --tensor-parallel-size 8
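Why these token shapes isolate prefill or decode can be seen with a simplified step count. The sketch below is hypothetical (`step_breakdown` is not a vLLM function) and assumes that all requests in the batch are prefilled together in one step and that this step also samples the first output token, so `--output-len N` leaves N - 1 decode steps. The real vLLM scheduler may split or interleave work differently.

```python
from typing import Dict


def step_breakdown(batch_size: int, input_len: int, output_len: int) -> Dict[str, int]:
    """Simplified model: one prefill step for the whole batch, then decode steps."""
    # Assumption: the prefill step already produces the first output token.
    decode_steps = max(output_len - 1, 0)
    return {
        "prefill_tokens": batch_size * input_len,
        "decode_steps": decode_steps,
        "decode_tokens": batch_size * decode_steps,
    }


print(step_breakdown(batch_size=1, input_len=1024, output_len=1))
# {'prefill_tokens': 1024, 'decode_steps': 0, 'decode_tokens': 0}
print(step_breakdown(batch_size=32, input_len=1, output_len=128))
# {'prefill_tokens': 32, 'decode_steps': 127, 'decode_tokens': 4064}
```

Under this model, --output-len 1 in the prefill example leaves no decode steps at all, while the decode example spends almost all of its work in batched decode steps.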

Visualizing the profiles

Once you have collected your profiles with this script, you can visualize them using TensorBoard.

These are most likely the dependencies you need to install:

pip install tensorflow-cpu tensorboard-plugin-profile etils importlib_resources

Then you just need to point TensorBoard to the directory where you saved the profiles and visit http://localhost:6006/ in your browser:

tensorboard --logdir profiles/ --port 6006
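Before launching TensorBoard, it can help to confirm that a trace was actually written. This sketch assumes the directory layout read by the TensorBoard profile plugin, `<logdir>/plugins/profile/<run>/<host>.xplane.pb`; the exact layout and file names may differ between PyTorch/XLA versions, and `find_xplane_files` is a hypothetical helper.

```python
from pathlib import Path
from typing import List


def find_xplane_files(logdir: str = "profiles") -> List[Path]:
    """List trace files, assuming the plugins/profile/<run>/*.xplane.pb layout."""
    root = Path(logdir) / "plugins" / "profile"
    if not root.is_dir():
        # No profile runs found under this log directory.
        return []
    return sorted(root.glob("*/*.xplane.pb"))


for path in find_xplane_files("profiles"):
    print(path)
```

An empty result usually means the capture window missed the workload or the trace was saved to a different --profile-result-dir.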

Example materials

profiling.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import dataclasses
import os
import time

import numpy as np
import torch_xla.debug.profiler as xp
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.utils import FlexibleArgumentParser

DURATION_MS = int(os.getenv("VLLM_TPU_PROFILE_DURATION_MS", 3000))
DELAY_MS = int(os.getenv("VLLM_TPU_PROFILE_DELAY_MS", 0))


def main(args: argparse.Namespace):
    print(args)

    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**dataclasses.asdict(engine_args))
    server = xp.start_server(9012)  # noqa: F841

    sampling_params = SamplingParams(
        temperature=0.0,
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)
    dummy_prompt_token_ids = np.random.randint(
        10000, size=(args.batch_size, args.input_len)
    )
    dummy_prompts: list[PromptType] = [
        {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
    ]

    def run_to_completion():
        start_time = time.perf_counter()
        llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
        end_time = time.perf_counter()
        latency = end_time - start_time
        return latency

    # Warmup
    print("Warming up...")
    warmup_latencies = []
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        warmup_latencies.append(run_to_completion())
    print(f"Average warmup latency: {np.mean(warmup_latencies):.4f}s")

    # Profile
    profile_dir = args.profile_result_dir
    print(f"Profiling (results will be saved to '{profile_dir}')...")
    # Enable tracing on server
    xp.trace_detached(
        "localhost:9012", profile_dir, delay_ms=DELAY_MS, duration_ms=DURATION_MS
    )
    if DELAY_MS == 0:
        # With no delay, give the detached tracer a moment to start
        # capturing before inference begins.
        time.sleep(1.0)
    profile_latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profile iterations"):
        profile_latencies.append(run_to_completion())
    print(f"Average profile latency: {np.mean(profile_latencies):.4f}s")

    return


def parse_args():
    parser = FlexibleArgumentParser(
        description="Benchmark the latency of processing a single batch of "
        "requests till completion."
    )
    parser.add_argument("--input-len", type=int, default=32)
    parser.add_argument("--output-len", type=int, default=128)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument(
        "--num-iters-warmup",
        type=int,
        default=5,
        help="Number of iterations to run for warmup.",
    )
    parser.add_argument(
        "--num-iters",
        type=int,
        default=1,
        help="Number of iterations to run for profiling.",
    )
    parser.add_argument(
        "--profile-result-dir",
        type=str,
        default="profiles",
        help=(
            "Path to save the PyTorch/XLA profiler output. Can be visualized "
            "with ui.perfetto.dev or TensorBoard "
            "(https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm)."
        ),
    )

    parser = EngineArgs.add_cli_args(parser)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(args)