Llava Example

Source: vllm-project/vllm.

 1import argparse
 2import os
 3import subprocess
 4
 5import torch
 6
 7from vllm import LLM
 8from vllm.sequence import MultiModalData
 9
10# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
11
12
def run_llava_pixel_values():
    """Run the LLaVA demo with the image passed in as pixel values.

    Builds an ``LLM`` for ``llava-hf/llava-1.5-7b-hf`` configured for
    ``pixel_values`` image input, loads the tensor saved at
    ``images/stop_sign_pixel_values.pt``, asks the model what the image
    contains, and prints the generated text of every returned output.
    """
    engine = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )

    # 576 "<image>" placeholders (same number as image_feature_size above)
    # precede the actual question.
    question = "\nUSER: What is the content of this image?\nASSISTANT:"
    prompt = "<image>" * 576 + question

    # In a real deployment this tensor would be produced by another online or
    # offline component rather than read from a local file.
    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects; only use it on files from a trusted source.
    pixel_values = torch.load("images/stop_sign_pixel_values.pt")

    request = {
        "prompt": prompt,
        "multi_modal_data": MultiModalData(
            type=MultiModalData.Type.IMAGE,
            data=pixel_values,
        ),
    }

    for request_output in engine.generate(request):
        # Only the first completion of each request is shown.
        print(request_output.outputs[0].text)
38
39
def run_llava_image_features():
    """Run the LLaVA demo with the image passed in as image features.

    Builds an ``LLM`` for ``llava-hf/llava-1.5-7b-hf`` configured for
    ``image_features`` image input, loads the tensor saved at
    ``images/stop_sign_image_features.pt``, asks the model what the image
    contains, and prints the generated text of every returned output.
    """
    engine = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="image_features",
        image_token_id=32000,
        image_input_shape="1,576,1024",
        image_feature_size=576,
    )

    # 576 "<image>" placeholders (same number as image_feature_size above)
    # precede the actual question.
    question = "\nUSER: What is the content of this image?\nASSISTANT:"
    prompt = "<image>" * 576 + question

    # In a real deployment this tensor would be produced by another online or
    # offline component rather than read from a local file.
    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects; only use it on files from a trusted source.
    image_features = torch.load("images/stop_sign_image_features.pt")

    request = {
        "prompt": prompt,
        "multi_modal_data": MultiModalData(
            type=MultiModalData.Type.IMAGE,
            data=image_features,
        ),
    }

    for request_output in engine.generate(request):
        # Only the first completion of each request is shown.
        print(request_output.outputs[0].text)
64
65
def main(args):
    """Run the demo variant selected by ``args.type``.

    ``"pixel_values"`` runs the pixel-value demo; any other value runs the
    image-feature demo.
    """
    demo = (run_llava_pixel_values
            if args.type == "pixel_values" else run_llava_image_features)
    demo()
71
72
if __name__ == "__main__":
    # Command-line interface: pick which image input type to demo.
    parser = argparse.ArgumentParser(description="Demo on Llava")
    parser.add_argument(
        "--type",
        type=str,
        choices=["pixel_values", "image_features"],
        default="pixel_values",
        help="image input type",
    )
    args = parser.parse_args()

    # Where the example assets live remotely and where they are mirrored to.
    s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
    local_directory = "images"

    # Create the download target if it is not there yet.
    os.makedirs(local_directory, exist_ok=True)

    # Mirror the bucket with the AWS CLI; "--no-sign-request" performs the
    # sync without credentials (anonymous access is assumed). check_call
    # raises if the command exits with a non-zero status.
    sync_command = [
        "aws",
        "s3",
        "sync",
        s3_bucket_path,
        local_directory,
        "--no-sign-request",
    ]
    subprocess.check_call(sync_command)

    main(args)