LLaVA Example
Source: vllm-project/vllm.
1import argparse
2import os
3import subprocess
4
5import torch
6
7from vllm import LLM
8from vllm.sequence import MultiModalData
9
10# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
11
12
def run_llava_pixel_values():
    """Run the LLaVA-1.5 demo with raw pixel values as the image input.

    Loads a pre-computed pixel-value tensor from disk, feeds it to the
    model alongside a placeholder-token prompt, and prints each reply.
    """
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )

    # One <image> placeholder token per visual feature, then the user query.
    prompt = ("<image>" * 576 +
              "\nUSER: What is the content of this image?\nASSISTANT:")

    # This should be provided by another online or offline component.
    image_tensor = torch.load("images/stop_sign_pixel_values.pt")

    mm_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=image_tensor)
    for request_output in llm.generate(prompt, multi_modal_data=mm_data):
        print(request_output.outputs[0].text)
34
35
def run_llava_image_features():
    """Run the LLaVA-1.5 demo with pre-extracted image features as input.

    Loads an image-feature tensor from disk, feeds it to the model
    alongside a placeholder-token prompt, and prints each reply.
    """
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="image_features",
        image_token_id=32000,
        image_input_shape="1,576,1024",
        image_feature_size=576,
    )

    # One <image> placeholder token per visual feature, then the user query.
    prompt = ("<image>" * 576 +
              "\nUSER: What is the content of this image?\nASSISTANT:")

    # This should be provided by another online or offline component.
    feature_tensor = torch.load("images/stop_sign_image_features.pt")

    mm_data = MultiModalData(type=MultiModalData.Type.IMAGE,
                             data=feature_tensor)
    for request_output in llm.generate(prompt, multi_modal_data=mm_data):
        print(request_output.outputs[0].text)
57
58
def main(args):
    """Dispatch to the demo matching ``args.type``.

    Any value other than "pixel_values" falls through to the
    image-features variant, mirroring the argparse choices.
    """
    runner = (run_llava_pixel_values
              if args.type == "pixel_values" else run_llava_image_features)
    runner()
64
65
66if __name__ == "__main__":
67 parser = argparse.ArgumentParser(description="Demo on Llava")
68 parser.add_argument("--type",
69 type=str,
70 choices=["pixel_values", "image_features"],
71 default="pixel_values",
72 help="image input type")
73 args = parser.parse_args()
74 # Download from s3
75 s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
76 local_directory = "images"
77
78 # Make sure the local directory exists or create it
79 os.makedirs(local_directory, exist_ok=True)
80
81 # Use AWS CLI to sync the directory, assume anonymous access
82 subprocess.check_call([
83 "aws",
84 "s3",
85 "sync",
86 s3_bucket_path,
87 local_directory,
88 "--no-sign-request",
89 ])
90 main(args)