Source examples/offline_inference/batch_llm_inference.py.
Batch LLM Inference
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use Ray Data for data parallel batch inference.
Ray Data is a data processing framework that can handle large datasets
and integrates tightly with vLLM for data-parallel inference.
As of Ray 2.44, Ray Data has a native integration with
vLLM (under ray.data.llm).
Ray Data provides functionality for:
* Reading and writing to cloud storage (S3, GCS, etc.)
* Automatic sharding and load-balancing across a cluster
* Optimized configuration of vLLM using continuous batching
* Compatibility with tensor/pipeline parallel inference
Learn more about Ray Data's LLM integration:
https://docs.ray.io/en/latest/data/working-with-llms.html
"""
import ray
from packaging.version import Version
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig
# Fail fast when the installed Ray is older than 2.44.1. An explicit raise
# (rather than `assert`) keeps the check active under `python -O`, which
# strips assert statements.
if Version(ray.__version__) < Version("2.44.1"):
    raise RuntimeError("Ray version must be at least 2.44.1")
# Uncomment to reduce clutter in stdout
# ray.init(log_to_driver=False)
# ray.data.DataContext.get_current().enable_progress_bars = False
# Load the prompts from a single text file on S3. Ray Data can also read
# multiple files from cloud storage (such as JSONL, Parquet, CSV, binary
# format).
prompts_uri = "s3://anonymous@air-example-data/prompts.txt"
ds = ray.data.read_text(prompts_uri)

# Show the dataset schema, then report how many prompts were loaded.
schema = ds.schema()
print(schema)
size = ds.count()
print(f"Size of dataset: {size} prompts")
# vLLM engine settings, passed to the processor config as `engine_kwargs`.
engine_kwargs = dict(
    enable_chunked_prefill=True,
    max_num_batched_tokens=4096,
    max_model_len=16384,
)

# Processor configuration for the vLLM engine.
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs=engine_kwargs,
    concurrency=1,  # number of parallel vLLM replicas
    batch_size=64,
)
def _build_chat_request(row):
    """Turn a dataset row into a chat request with sampling parameters.

    The row's "text" value becomes the user message, preceded by a fixed
    system prompt asking for haiku responses.
    """
    system_message = {
        "role": "system",
        "content": "You are a bot that responds with haikus."
    }
    user_message = {"role": "user", "content": row["text"]}
    return {
        "messages": [system_message, user_message],
        "sampling_params": {"temperature": 0.3, "max_tokens": 250},
    }


def _attach_answer(row):
    """Copy the generated text into an "answer" column.

    Every column already present in ``row`` is carried over unchanged.
    """
    return dict(answer=row["generated_text"], **row)


# Build the Processor object that performs batch inference on the dataset.
vllm_processor = build_llm_processor(
    config,
    preprocess=_build_chat_request,
    postprocess=_attach_answer,
)
# Apply the processor to the dataset.
ds = vllm_processor(ds)

# Peek at the first 10 results.
# NOTE: This is meant for local testing and debugging. In production, write
# the full result out instead (see the Parquet example below).
for result in ds.take(limit=10):
    prompt, generated_text = result["prompt"], result["generated_text"]
    print(f"Prompt: {prompt!r}")
    print(f"Generated text: {generated_text!r}")
# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")