Llava Example

Source: vllm-project/vllm.

import argparse
import os
import subprocess

import torch
from PIL import Image

from vllm import LLM
from vllm.multimodal.image import ImageFeatureData, ImagePixelData

# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them.


def run_llava_pixel_values(*, disable_image_processor: bool = False):
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
        disable_image_processor=disable_image_processor,
    )

    # Each image occupies `image_feature_size` placeholder tokens in the
    # prompt: a 336x336 image split into 14x14 patches yields 24*24 = 576.
    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    if disable_image_processor:
        # Pre-processed pixel values; bypasses the HF image processor.
        image = torch.load("images/stop_sign_pixel_values.pt")
    else:
        image = Image.open("images/stop_sign.jpg")

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": ImagePixelData(image),
    })

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def run_llava_image_features():
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="image_features",
        image_token_id=32000,
        image_input_shape="1,576,1024",
        image_feature_size=576,
    )

    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    # Pre-computed vision-tower output: 576 patch embeddings of size 1024,
    # so the vision encoder is skipped at inference time.
    image: torch.Tensor = torch.load("images/stop_sign_image_features.pt")

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": ImageFeatureData(image),
    })

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def main(args):
    if args.type == "pixel_values":
        run_llava_pixel_values()
    else:
        run_llava_image_features()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Demo on Llava")
    parser.add_argument("--type",
                        type=str,
                        choices=["pixel_values", "image_features"],
                        default="pixel_values",
                        help="image input type")
    args = parser.parse_args()

    # Download the image assets from S3.
    s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
    local_directory = "images"

    # Make sure the local directory exists or create it.
    os.makedirs(local_directory, exist_ok=True)

    # Use the AWS CLI to sync the directory; `--no-sign-request` allows
    # anonymous access.
    subprocess.check_call([
        "aws",
        "s3",
        "sync",
        s3_bucket_path,
        local_directory,
        "--no-sign-request",
    ])
    main(args)
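By default the script runs the pixel_values path; pass --type image_features to feed pre-computed features instead.

For context on the image_features path: it expects the vision tower's output directly, matching the declared image_input_shape="1,576,1024". Below is a minimal sketch of how a tensor with that shape could be produced with Hugging Face transformers, assuming LLaVA-1.5's usual feature selection (penultimate hidden layer of openai/clip-vit-large-patch14-336, CLS token dropped). The exact recipe behind the S3 asset is not shown in this example, so treat this as illustrative rather than authoritative.

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

# Assumption: LLaVA-1.5's vision tower and feature-selection strategy.
model_name = "openai/clip-vit-large-patch14-336"
processor = CLIPImageProcessor.from_pretrained(model_name)
vision_tower = CLIPVisionModel.from_pretrained(model_name)

image = Image.open("images/stop_sign.jpg")
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    out = vision_tower(**inputs, output_hidden_states=True)

# Penultimate hidden layer with the CLS token dropped -> (1, 576, 1024).
features = out.hidden_states[-2][:, 1:]
torch.save(features, "images/stop_sign_image_features.pt")

Note that the features stay at the CLIP hidden size of 1024 rather than the language model's hidden size, which indicates that vLLM applies the multimodal projector internally.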