LLaVA Example
Source: vllm-project/vllm.
1import argparse
2import os
3import subprocess
4
5import torch
6
7from vllm import LLM
8from vllm.sequence import MultiModalData
9
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`; the script syncs them into the local `images/` directory before running.
11
12
# Number of "<image>" placeholder tokens in the prompt, and the value passed
# as ``image_feature_size`` to ``LLM``. Kept in one place so the two cannot
# drift apart.
_LLAVA_IMAGE_FEATURE_SIZE = 576


def _run_llava(image_input_type: str, image_input_shape: str,
               image_path: str) -> None:
    """Run one LLaVA-1.5 generation on a saved tensor and print the result.

    Shared implementation for the two public demo entry points. Those
    entry points differ only in how the image is supplied to the model.

    Args:
        image_input_type: Forwarded to ``LLM(image_input_type=...)``.
        image_input_shape: Forwarded to ``LLM(image_input_shape=...)``, as a
            comma-separated shape string such as ``"1,3,336,336"``.
        image_path: Path of the tensor file loaded with ``torch.load``.
    """
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type=image_input_type,
        image_token_id=32000,
        image_input_shape=image_input_shape,
        image_feature_size=_LLAVA_IMAGE_FEATURE_SIZE,
    )

    prompt = "<image>" * _LLAVA_IMAGE_FEATURE_SIZE + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    # This should be provided by another online or offline component.
    # NOTE(review): torch.load unpickles its input, and these files are
    # fetched anonymously from a public S3 bucket. If they are plain tensors,
    # consider torch.load(..., weights_only=True) -- confirm before changing.
    image = torch.load(image_path)

    outputs = llm.generate({
        "prompt":
        prompt,
        "multi_modal_data":
        MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
    })

    # Print the text of the first completion of each request output.
    for o in outputs:
        print(o.outputs[0].text)


def run_llava_pixel_values() -> None:
    """Run the LLaVA demo, feeding preprocessed pixel values to the model.

    Loads ``images/stop_sign_pixel_values.pt`` and declares the input shape
    ``1,3,336,336`` to vLLM.
    """
    _run_llava(
        image_input_type="pixel_values",
        image_input_shape="1,3,336,336",
        image_path="images/stop_sign_pixel_values.pt",
    )


def run_llava_image_features() -> None:
    """Run the LLaVA demo, feeding precomputed image features to the model.

    Loads ``images/stop_sign_image_features.pt`` and declares the input
    shape ``1,576,1024`` to vLLM.
    """
    _run_llava(
        image_input_type="image_features",
        image_input_shape="1,576,1024",
        image_path="images/stop_sign_image_features.pt",
    )
64
65
def main(args):
    """Run the LLaVA demo variant chosen on the command line.

    ``args.type == "pixel_values"`` selects the pixel-values demo. Every
    other value selects the image-features demo.
    """
    runner = (run_llava_pixel_values
              if args.type == "pixel_values" else run_llava_image_features)
    runner()
71
72
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Demo on Llava")
    cli.add_argument(
        "--type",
        type=str,
        choices=["pixel_values", "image_features"],
        default="pixel_values",
        help="image input type",
    )
    parsed = cli.parse_args()

    # Fetch the demo tensors before running the model. Arguments are parsed
    # first so that e.g. --help exits without touching the network.
    bucket = "s3://air-example-data-2/vllm_opensource_llava/"
    target_dir = "images"
    os.makedirs(target_dir, exist_ok=True)

    # --no-sign-request: read the bucket anonymously (no AWS credentials).
    sync_cmd = ["aws", "s3", "sync", bucket, target_dir, "--no-sign-request"]
    subprocess.check_call(sync_cmd)

    main(parsed)