Offline Inference Vision Language
Source: vllm-project/vllm.
1"""
2This example shows how to use vLLM for running offline inference
3with the correct prompt format on vision language models.
4
5For most models, the prompt format should follow corresponding examples
6on HuggingFace model repository.
7"""
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.utils import FlexibleArgumentParser

# Input image and question
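# (ImageAsset provides a named sample image, fetched and cached on first
# use; .pil_image returns it as a PIL image.)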
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
question = "What is the content of this image?"


# LLaVA-1.5
def run_llava(question):
    prompt = f"USER: <image>\n{question}\nASSISTANT:"

    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(question):
    prompt = f"[INST] <image>\n{question} [/INST]"
    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Fuyu
def run_fuyu(question):
    prompt = f"{question}\n"
    llm = LLM(model="adept/fuyu-8b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Phi-3-Vision
def run_phi3v(question):
    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"  # noqa: E501
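    # (<|image_1|> is the placeholder marking where the image features are
    # inserted in the Phi-3-Vision prompt.)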
    # Note: The default settings of max_num_seqs (256) and
    # max_model_len (128k) for this model may cause OOM.
    # You may lower either value to run this example on lower-end GPUs.

    # In this example, we override max_num_seqs to 5 while
    # keeping the original context length of 128k.
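    # (Illustrative alternative: pass e.g. max_model_len=4096 to LLM()
    # below to cap the context length instead.)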
    llm = LLM(
        model="microsoft/Phi-3-vision-128k-instruct",
        trust_remote_code=True,
        max_num_seqs=5,
    )
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# PaliGemma
def run_paligemma(question):
    # PaliGemma uses a special task-prefix prompt format; "caption en"
    # requests an English caption, so the free-form question is unused here.
    prompt = "caption en"
    llm = LLM(model="google/paligemma-3b-mix-224")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Chameleon
def run_chameleon(question):
    prompt = f"{question}<image>"
    llm = LLM(model="facebook/chameleon-7b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# MiniCPM-V
def run_minicpmv(question):
    # 2.0
    # The official repo doesn't work yet, so we need to use a fork for now.
    # For more details, see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630  # noqa
    # model_name = "HwwwH/MiniCPM-V-2"

    # 2.5
    # model_name = "openbmb/MiniCPM-Llama3-V-2_5"

    # 2.6
    model_name = "openbmb/MiniCPM-V-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
    )
    # NOTE: The stop_token_ids differ across MiniCPM-V versions.
    # 2.0
    # stop_token_ids = [tokenizer.eos_id]

    # 2.5
    # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]

    # 2.6
    stop_tokens = ['<|im_end|>', '<|endoftext|>']
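    # (convert_tokens_to_ids maps each stop string to its token id so vLLM
    # can match it during decoding.)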
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    messages = [{
        'role': 'user',
        'content': f'(<image>./</image>)\n{question}'
    }]
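    # apply_chat_template renders the messages into the model's chat
    # format; tokenize=False returns the prompt as a string.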
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    return llm, prompt, stop_token_ids


# InternVL
def run_internvl(question):
    model_name = "OpenGVLab/InternVL2-2B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL.
    # Model variants may have different stop tokens;
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return llm, prompt, stop_token_ids


# BLIP-2
def run_blip2(question):
    # The BLIP-2 prompt format shown in the HuggingFace model repository
    # is inaccurate. See
    # https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262  # noqa
    prompt = f"Question: {question} Answer:"
    llm = LLM(model="Salesforce/blip2-opt-2.7b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


model_example_map = {
    "llava": run_llava,
    "llava-next": run_llava_next,
    "fuyu": run_fuyu,
    "phi3_v": run_phi3v,
    "paligemma": run_paligemma,
    "chameleon": run_chameleon,
    "minicpmv": run_minicpmv,
    "blip-2": run_blip2,
    "internvl_chat": run_internvl,
}


def main(args):
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    llm, prompt, stop_token_ids = model_example_map[model](question)

    # We set temperature to 0.2 so that outputs can differ even when all
    # prompts are identical during batch inference.
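    # (With stop_token_ids=None, generation stops at the model's EOS token
    # or after max_tokens.)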
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)

    assert args.num_prompts > 0
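    # llm.generate accepts either a single input dict or a list of them.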
    if args.num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompt,
            "multi_modal_data": {
                "image": image
            },
        }
    else:
        # Batch inference
        inputs = [{
            "prompt": prompt,
            "multi_modal_data": {
                "image": image
            },
        } for _ in range(args.num_prompts)]

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="llava",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=1,
                        help='Number of prompts to run.')

    args = parser.parse_args()
    main(args)
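To run the example, invoke the script with a model type from model_example_map (the script path below is assumed from this page's title; adjust it to wherever the example lives in your checkout):

python examples/offline_inference_vision_language.py --model-type llava --num-prompts 4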