Offline Inference Distributed

Source vllm-project/vllm.

  1"""
  2This example shows how to use Ray Data for running offline batch inference
  3in a distributed manner on a multi-node cluster.
  4
  5Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
  6"""
  7
  8from typing import Dict
  9
 10import numpy as np
 11import ray
 12from packaging.version import Version
 13from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 14
 15from vllm import LLM, SamplingParams
 16
 17assert Version(ray.__version__) >= Version(
 18    "2.22.0"), "Ray version must be at least 2.22.0"
 19
 20# Create a sampling params object.
 21sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 22
 23# Set tensor parallelism per instance.
 24tensor_parallel_size = 1
 25
 26# Set number of instances. Each instance will use tensor_parallel_size GPUs.
 27num_instances = 1
 28
 29
 30# Create a class to do batch inference.
 31class LLMPredictor:
 32
 33    def __init__(self):
 34        # Create an LLM.
 35        self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
 36                       tensor_parallel_size=tensor_parallel_size)
 37
 38    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
 39        # Generate texts from the prompts.
 40        # The output is a list of RequestOutput objects that contain the prompt,
 41        # generated text, and other information.
 42        outputs = self.llm.generate(batch["text"], sampling_params)
 43        prompt = []
 44        generated_text = []
 45        for output in outputs:
 46            prompt.append(output.prompt)
 47            generated_text.append(' '.join([o.text for o in output.outputs]))
 48        return {
 49            "prompt": prompt,
 50            "generated_text": generated_text,
 51        }
 52
 53
 54# Read one text file from S3. Ray Data supports reading multiple files
 55# from cloud storage (such as JSONL, Parquet, CSV, binary format).
 56ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
 57
 58
 59# For tensor_parallel_size > 1, we need to create placement groups for vLLM
 60# to use. Every actor has to have its own placement group.
 61def scheduling_strategy_fn():
 62    # One bundle per tensor parallel worker
 63    pg = ray.util.placement_group(
 64        [{
 65            "GPU": 1,
 66            "CPU": 1
 67        }] * tensor_parallel_size,
 68        strategy="STRICT_PACK",
 69    )
 70    return dict(scheduling_strategy=PlacementGroupSchedulingStrategy(
 71        pg, placement_group_capture_child_tasks=True))
 72
 73
 74resources_kwarg = {}
 75if tensor_parallel_size == 1:
 76    # For tensor_parallel_size == 1, we simply set num_gpus=1.
 77    resources_kwarg["num_gpus"] = 1
 78else:
 79    # Otherwise, we have to set num_gpus=0 and provide
 80    # a function that will create a placement group for
 81    # each instance.
 82    resources_kwarg["num_gpus"] = 0
 83    resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn
 84
 85# Apply batch inference for all input data.
 86ds = ds.map_batches(
 87    LLMPredictor,
 88    # Set the concurrency to the number of LLM instances.
 89    concurrency=num_instances,
 90    # Specify the batch size for inference.
 91    batch_size=32,
 92    **resources_kwarg,
 93)
 94
 95# Peek first 10 results.
 96# NOTE: This is for local testing and debugging. For production use case,
 97# one should write full result out as shown below.
 98outputs = ds.take(limit=10)
 99for output in outputs:
100    prompt = output["prompt"]
101    generated_text = output["generated_text"]
102    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
103
104# Write inference output data out as Parquet files to S3.
105# Multiple files would be written to the output destination,
106# and each task would write one or more files separately.
107#
108# ds.write_parquet("s3://<your-output-bucket>")