Offline Inference Distributed
Source: vllm-project/vllm.
1"""
2This example shows how to use Ray Data for running offline batch inference
3distributively on a multi-nodes cluster.
4
5Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
6"""
7
from typing import Dict

import numpy as np
import ray

from vllm import LLM, SamplingParams

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
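# Other decoding options can be set here as well; `max_tokens`, for example,
# caps the generated length (the value below is illustrative, not part of
# the original example):
# sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)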


# Create a class to do batch inference.
class LLMPredictor:

    def __init__(self):
        # Create an LLM.
        self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        # Generate texts from the prompts.
        # The output is a list of RequestOutput objects that contain the
        # prompt, generated text, and other information.
        outputs = self.llm.generate(batch["text"], sampling_params)
        prompt = []
        generated_text = []
        for output in outputs:
            prompt.append(output.prompt)
            generated_text.append(' '.join([o.text for o in output.outputs]))
        return {
            "prompt": prompt,
            "generated_text": generated_text,
        }

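# Note: Ray Data passes each batch to `__call__` as a dict mapping column
# names to NumPy arrays (matching the type hint above); `read_text` below
# yields a single "text" column, hence the `batch["text"]` lookup.
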
# Read one text file from S3. Ray Data also supports reading multiple files
# from cloud storage in other formats (e.g., JSONL, Parquet, CSV, binary).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
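# Other Ray Data readers follow the same pattern (bucket paths below are
# placeholders, not part of this example):
# ds = ray.data.read_json("s3://<your-bucket>/prompts.jsonl")
# ds = ray.data.read_parquet("s3://<your-bucket>/prompts/")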

# Apply batch inference for all input data.
ds = ds.map_batches(
    LLMPredictor,
    # Set the concurrency to the number of LLM instances.
    concurrency=10,
    # Specify the number of GPUs required per LLM instance.
    # NOTE: Do NOT set `num_gpus` when using vLLM with tensor parallelism
    # (i.e., `tensor_parallel_size` > 1); see the sketch below this call.
    num_gpus=1,
    # Specify the batch size for inference.
    batch_size=32,
)
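
# When each LLM instance spans multiple GPUs (`tensor_parallel_size` > 1),
# one common pattern is to leave `num_gpus` unset and reserve GPUs through
# a placement group instead. A minimal commented sketch, not part of this
# example (`tensor_parallel_size` is a hypothetical variable here, and
# `ray_remote_args_fn` support is assumed in recent Ray Data):
#
# from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
#
# def scheduling_strategy_fn():
#     # Reserve one GPU bundle per tensor-parallel worker.
#     pg = ray.util.placement_group(
#         [{"GPU": 1, "CPU": 1}] * tensor_parallel_size)
#     return dict(scheduling_strategy=PlacementGroupSchedulingStrategy(
#         pg, placement_group_capture_child_tasks=True))
#
# ds = ds.map_batches(LLMPredictor, concurrency=10, batch_size=32,
#                     ray_remote_args_fn=scheduling_strategy_fn)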

# Peek at the first 10 results.
# NOTE: This is for local testing and debugging. For production use cases,
# write the full results out as shown below.
outputs = ds.take(limit=10)
for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Write inference output data out as Parquet files to S3.
# Multiple files will be written to the output destination,
# and each task will write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
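
# To run this script against a multi-node Ray cluster, connect to the
# cluster before creating the Dataset (a sketch; assumes a running cluster
# started with `ray start` or a Ray cluster launcher):
#
# ray.init(address="auto")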