Llava Example

Source: vllm-project/vllm.

import argparse
import os
import subprocess

import torch

from vllm import LLM
from vllm.sequence import MultiModalData

# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
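# They are downloaded into the local `images/` directory by the `__main__`
# block at the bottom of this script.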


def run_llava_pixel_values():
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )
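    # The configuration above reflects LLaVA-1.5's CLIP ViT-L/14 vision tower
    # at 336x336 resolution: one RGB image of shape 1,3,336,336 is split into
    # (336 / 14) ** 2 = 576 patches, so the prompt below reserves 576
    # placeholder positions with the `<image>` token (id 32000 in this
    # checkpoint's vocabulary).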

    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    # This should be provided by another online or offline component.
    images = torch.load("images/stop_sign_pixel_values.pt")

    outputs = llm.generate(prompt,
                           multi_modal_data=MultiModalData(
                               type=MultiModalData.Type.IMAGE, data=images))
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def run_llava_image_features():
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="image_features",
        image_token_id=32000,
        image_input_shape="1,576,1024",
        image_feature_size=576,
    )
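    # With the `image_features` input type, the engine consumes precomputed
    # vision features instead of raw pixels: the 1,576,1024 shape above is
    # 576 patch embeddings of width 1024 (the CLIP ViT-L hidden size) per
    # image. The bundled `stop_sign_image_features.pt` asset is assumed to
    # have been produced offline in this format.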

    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    # This should be provided by another online or offline component.
    images = torch.load("images/stop_sign_image_features.pt")

    outputs = llm.generate(prompt,
                           multi_modal_data=MultiModalData(
                               type=MultiModalData.Type.IMAGE, data=images))
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def main(args):
    if args.type == "pixel_values":
        run_llava_pixel_values()
    else:
        run_llava_image_features()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Demo on Llava")
    parser.add_argument("--type",
                        type=str,
                        choices=["pixel_values", "image_features"],
                        default="pixel_values",
                        help="image input type")
    args = parser.parse_args()
    # Download from s3
    s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
    local_directory = "images"

    # Make sure the local directory exists or create it
    os.makedirs(local_directory, exist_ok=True)

    # Use AWS CLI to sync the directory, assume anonymous access
    subprocess.check_call([
        "aws",
        "s3",
        "sync",
        s3_bucket_path,
        local_directory,
        "--no-sign-request",
    ])
    main(args)
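To run the example, save the script locally (for example as llava_example.py, a hypothetical filename) and make sure the AWS CLI is installed, since the `__main__` block shells out to `aws s3 sync` with anonymous access to fetch the assets. Then invoke it with `python llava_example.py --type pixel_values` or `--type image_features`.

For reference, a tensor in the same format as `stop_sign_pixel_values.pt` could be produced offline with the Hugging Face image processor shipped with the same checkpoint. The sketch below only illustrates that assumption; the local image path is a placeholder, and the exact preprocessing used to create the bundled assets is not documented here.

from PIL import Image

import torch
from transformers import CLIPImageProcessor

# The llava-1.5-7b-hf checkpoint ships a CLIP image processor config
# (resize and center-crop to 336x336, CLIP normalization).
processor = CLIPImageProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Placeholder path; any RGB image works.
image = Image.open("stop_sign.jpg").convert("RGB")

# The resulting shape is (1, 3, 336, 336), matching image_input_shape above.
pixel_values = processor(images=image, return_tensors="pt").pixel_values
torch.save(pixel_values, "images/stop_sign_pixel_values.pt")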