Offline Inference Distributed
Source: vllm-project/vllm.
1"""
2This example shows how to use Ray Data for running offline batch inference
3distributively on a multi-nodes cluster.
4
5Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
6"""
7
8from typing import Any, Dict, List
9
10import numpy as np
11import ray
12from packaging.version import Version
13from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
14
15from vllm import LLM, SamplingParams
16
17assert Version(ray.__version__) >= Version(
18 "2.22.0"), "Ray version must be at least 2.22.0"
19
20# Create a sampling params object.
21sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
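# SamplingParams accepts further options; for example, the output length could
# be capped with max_tokens (the values below are illustrative, not required):
# sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)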

# Set tensor parallelism per instance.
tensor_parallel_size = 1

# Set number of instances. Each instance will use tensor_parallel_size GPUs.
num_instances = 1
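# In total, num_instances * tensor_parallel_size GPUs are needed; e.g. with
# tensor_parallel_size=2 and num_instances=4, the cluster must provide 8 GPUs.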


# Create a class to do batch inference.
class LLMPredictor:

    def __init__(self):
        # Create an LLM.
        self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
                       tensor_parallel_size=tensor_parallel_size)

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        # Generate texts from the prompts.
        # The output is a list of RequestOutput objects that contain the
        # prompt, generated text, and other information.
        outputs = self.llm.generate(batch["text"], sampling_params)
        prompt: List[str] = []
        generated_text: List[str] = []
        for output in outputs:
            prompt.append(output.prompt)
            generated_text.append(' '.join([o.text for o in output.outputs]))
        return {
            "prompt": prompt,
            "generated_text": generated_text,
        }


# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (such as JSONL, Parquet, CSV, and binary formats).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
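# Other Ray Data readers follow the same pattern; for example (the input paths
# below are hypothetical):
#
# ds = ray.data.read_json("s3://<your-bucket>/prompts.jsonl")
# ds = ray.data.read_parquet("s3://<your-bucket>/prompts.parquet")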


# For tensor_parallel_size > 1, we need to create placement groups for vLLM
# to use. Every actor has to have its own placement group.
def scheduling_strategy_fn():
    # One bundle per tensor parallel worker
    pg = ray.util.placement_group(
        [{
            "GPU": 1,
            "CPU": 1
        }] * tensor_parallel_size,
        strategy="STRICT_PACK",
    )
    return dict(scheduling_strategy=PlacementGroupSchedulingStrategy(
        pg, placement_group_capture_child_tasks=True))
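# Note: STRICT_PACK asks Ray to place all bundles of the placement group on a
# single node, so the tensor parallel workers of one instance stay co-located.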


resources_kwarg: Dict[str, Any] = {}
if tensor_parallel_size == 1:
    # For tensor_parallel_size == 1, we simply set num_gpus=1.
    resources_kwarg["num_gpus"] = 1
else:
    # Otherwise, we have to set num_gpus=0 and provide
    # a function that will create a placement group for
    # each instance.
    resources_kwarg["num_gpus"] = 0
    resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn

# Apply batch inference for all input data.
ds = ds.map_batches(
    LLMPredictor,
    # Set the concurrency to the number of LLM instances.
    concurrency=num_instances,
    # Specify the batch size for inference.
    batch_size=32,
    **resources_kwarg,
)
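# Note that map_batches is lazy: no inference actually runs until the dataset
# is consumed, e.g. by take() or write_parquet() below.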

# Peek at the first 10 results.
# NOTE: This is for local testing and debugging. For production use cases,
# write the full results out as shown below.
outputs = ds.take(limit=10)
for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Write inference output data out as Parquet files to S3.
# Multiple files will be written to the output destination,
# and each task will write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
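# Ray Data provides other writers as well; for example, to write JSON lines
# instead (hypothetical output path):
#
# ds.write_json("s3://<your-output-bucket>")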