Llava Example
Source: vllm-project/vllm.
1import argparse
2import os
3import subprocess
4
5import torch
6from PIL import Image
7
8from vllm import LLM
9from vllm.multimodal.image import ImageFeatureData, ImagePixelData
10
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them; running this
# script directly also syncs them into the local `images/` directory.


def run_llava_pixel_values(*, disable_image_processor: bool = False):
    """Run LLaVA-1.5-7B on the stop-sign image supplied as pixel values.

    Args:
        disable_image_processor: Forwarded to ``LLM(...)``. When False
            (default), ``images/stop_sign.jpg`` is opened with PIL and passed
            in. When True, a pre-computed pixel-value tensor is loaded from
            ``images/stop_sign_pixel_values.pt`` instead.

    The generated text of every output is printed to stdout.
    """
    # One source of truth for the feature size: the prompt must contain
    # exactly this many ``<image>`` placeholders, so derive both from it.
    image_feature_size = 576

    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        # presumably the tokenizer id of "<image>" -- confirm for the model
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=image_feature_size,
        disable_image_processor=disable_image_processor,
    )

    prompt = "<image>" * image_feature_size + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    if disable_image_processor:
        # The .pt file is downloaded from a public bucket; weights_only=True
        # restricts unpickling to tensors/plain containers so loading it
        # cannot execute arbitrary pickled code.
        image = torch.load("images/stop_sign_pixel_values.pt",
                           weights_only=True)
    else:
        image = Image.open("images/stop_sign.jpg")

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": ImagePixelData(image),
    })

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def run_llava_image_features():
    """Run LLaVA-1.5-7B on pre-computed image features for the stop-sign image.

    Loads ``images/stop_sign_image_features.pt`` and passes it to the model
    wrapped in ``ImageFeatureData``. The LLM is configured with
    ``image_input_shape="1,576,1024"``. The generated text of every output is
    printed to stdout.
    """
    # One source of truth for the feature size: the prompt must contain
    # exactly this many ``<image>`` placeholders, so derive both from it.
    image_feature_size = 576

    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="image_features",
        # presumably the tokenizer id of "<image>" -- confirm for the model
        image_token_id=32000,
        image_input_shape="1,576,1024",
        image_feature_size=image_feature_size,
    )

    prompt = "<image>" * image_feature_size + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    # The .pt file is downloaded from a public bucket; weights_only=True
    # restricts unpickling to tensors/plain containers so loading it cannot
    # execute arbitrary pickled code.
    image: torch.Tensor = torch.load("images/stop_sign_image_features.pt",
                                     weights_only=True)

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": ImageFeatureData(image),
    })

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def main(args):
    """Dispatch to the demo selected by ``args.type``.

    ``"pixel_values"`` runs the pixel-value demo; any other value runs the
    pre-computed image-feature demo.
    """
    if args.type != "pixel_values":
        run_llava_image_features()
        return
    run_llava_pixel_values()


if __name__ == "__main__":
    # Command-line interface: choose which kind of image input to demo.
    cli = argparse.ArgumentParser(description="Demo on Llava")
    cli.add_argument(
        "--type",
        type=str,
        choices=["pixel_values", "image_features"],
        default="pixel_values",
        help="image input type",
    )
    args = cli.parse_args()

    # Fetch the demo assets from S3 into ./images before running.
    bucket_uri = "s3://air-example-data-2/vllm_opensource_llava/"
    target_dir = "images"

    # Create the destination if missing; no error if it already exists.
    os.makedirs(target_dir, exist_ok=True)

    # Unsigned (anonymous) sync via the AWS CLI; check_call raises
    # CalledProcessError if the command exits non-zero.
    sync_cmd = [
        "aws",
        "s3",
        "sync",
        bucket_uri,
        target_dir,
        "--no-sign-request",
    ]
    subprocess.check_call(sync_cmd)

    main(args)