Offline Inference Vision Language
Source vllm-project/vllm.
1"""
2This example shows how to use vLLM for running offline inference
3with the correct prompt format on vision language models.
4
5For most models, the prompt format should follow corresponding examples
6on HuggingFace model repository.
7"""
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.utils import FlexibleArgumentParser


# LLaVA-1.5
def run_llava(question, modality):
    assert modality == "image"

    prompt = f"USER: <image>\n{question}\nASSISTANT:"

    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(question, modality):
    assert modality == "image"

    prompt = f"[INST] <image>\n{question} [/INST]"
    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLaVA-NeXT-Video
# Currently only supports video input
def run_llava_next_video(question, modality):
    assert modality == "video"

    prompt = f"USER: <video>\n{question} ASSISTANT:"
    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLaVA-OneVision
def run_llava_onevision(question, modality):

    if modality == "video":
        prompt = f"<|im_start|>user <video>\n{question}<|im_end|> \
    <|im_start|>assistant\n"

    elif modality == "image":
        prompt = f"<|im_start|>user <image>\n{question}<|im_end|> \
    <|im_start|>assistant\n"

    llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
              max_model_len=32768)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Fuyu
def run_fuyu(question, modality):
    assert modality == "image"

    prompt = f"{question}\n"
    llm = LLM(model="adept/fuyu-8b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Phi-3-Vision
def run_phi3v(question, modality):
    assert modality == "image"

    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"  # noqa: E501
    # Note: The default setting of max_num_seqs (256) and
    # max_model_len (128k) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.

    # In this example, we override max_num_seqs to 5 while
    # keeping the original context length of 128k.

    # num_crops is an override kwarg to the multimodal image processor;
    # for some models, e.g., Phi-3.5-vision-instruct, it is recommended
    # to use 16 for single-frame scenarios and 4 for multi-frame.
    #
    # Generally speaking, a larger value for num_crops results in more
    # tokens per image instance, because it may scale the image more in
    # the image preprocessing. References in the model docs and the
    # formula for the number of image tokens after the preprocessing
    # transform can be found below.
    #
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
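    # For example, a multi-frame use case would follow the guidance above and
    # pass mm_processor_kwargs={"num_crops": 4} instead of the single-frame
    # value of 16 used below (illustrative only; tune for your workload).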
    llm = LLM(
        model="microsoft/Phi-3-vision-128k-instruct",
        trust_remote_code=True,
        max_num_seqs=5,
        mm_processor_kwargs={"num_crops": 16},
    )
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# PaliGemma
def run_paligemma(question, modality):
    assert modality == "image"

    # PaliGemma has a special prompt format for VQA
    prompt = "caption en"
    llm = LLM(model="google/paligemma-3b-mix-224")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Chameleon
def run_chameleon(question, modality):
    assert modality == "image"

    prompt = f"{question}<image>"
    llm = LLM(model="facebook/chameleon-7b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# MiniCPM-V
def run_minicpmv(question, modality):
    assert modality == "image"

    # 2.0
    # The official repo doesn't work yet, so we need to use a fork for now.
    # For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630  # noqa
    # model_name = "HwwwH/MiniCPM-V-2"

    # 2.5
    # model_name = "openbmb/MiniCPM-Llama3-V-2_5"

    # 2.6
    model_name = "openbmb/MiniCPM-V-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
    )
    # NOTE: The stop_token_ids are different for various versions of MiniCPM-V
    # 2.0
    # stop_token_ids = [tokenizer.eos_id]

    # 2.5
    # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]

    # 2.6
    stop_tokens = ['<|im_end|>', '<|endoftext|>']
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    messages = [{
        'role': 'user',
        'content': f'(<image>./</image>)\n{question}'
    }]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
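    # With tokenize=False, apply_chat_template returns the fully formatted
    # prompt string (including the image placeholder) rather than token IDs.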
    return llm, prompt, stop_token_ids


# InternVL
def run_internvl(question, modality):
    assert modality == "image"

    model_name = "OpenGVLab/InternVL2-2B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL
    # Model variants may have different stop tokens;
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return llm, prompt, stop_token_ids


# BLIP-2
def run_blip2(question, modality):
    assert modality == "image"

    # The BLIP-2 prompt format on the HuggingFace model repository is inaccurate.
    # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262  # noqa
    prompt = f"Question: {question} Answer:"
    llm = LLM(model="Salesforce/blip2-opt-2.7b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Qwen-VL
def run_qwen_vl(question, modality):
    assert modality == "image"

    llm = LLM(
        model="Qwen/Qwen-VL",
        trust_remote_code=True,
        max_num_seqs=5,
    )

    prompt = f"{question}Picture 1: <img></img>\n"
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Qwen2-VL
def run_qwen2_vl(question, modality):
    assert modality == "image"

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    llm = LLM(
        model=model_name,
        max_num_seqs=5,
    )

    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
              f"{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Llama 3.2
def run_mllama(question, modality):
    assert modality == "image"

    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

    # Note: The default setting of max_num_seqs (256) and
    # max_model_len (131072) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.

    # The configuration below has been confirmed to launch on a
    # single H100 GPU.
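    # enforce_eager=True skips CUDA graph capture, trading some speed for
    # lower memory usage.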
    llm = LLM(
        model=model_name,
        max_num_seqs=16,
        enforce_eager=True,
    )

    prompt = f"<|image|><|begin_of_text|>{question}"
    stop_token_ids = None
    return llm, prompt, stop_token_ids


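# Maps the HuggingFace "model_type" string (the value passed via --model-type
# below) to the function that builds the corresponding LLM, prompt, and
# stop token IDs.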
model_example_map = {
    "llava": run_llava,
    "llava-next": run_llava_next,
    "llava-next-video": run_llava_next_video,
    "llava-onevision": run_llava_onevision,
    "fuyu": run_fuyu,
    "phi3_v": run_phi3v,
    "paligemma": run_paligemma,
    "chameleon": run_chameleon,
    "minicpmv": run_minicpmv,
    "blip-2": run_blip2,
    "internvl_chat": run_internvl,
    "qwen_vl": run_qwen_vl,
    "qwen2_vl": run_qwen2_vl,
    "mllama": run_mllama,
}


def get_multi_modal_input(args):
    """
    return {
        "data": image or video,
        "question": question,
    }
    """
    if args.modality == "image":
        # Input image and question
        image = ImageAsset("cherry_blossom") \
            .pil_image.convert("RGB")
        img_question = "What is the content of this image?"

        return {
            "data": image,
            "question": img_question,
        }

    if args.modality == "video":
        # Input video and question
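        # The sample video's frames are returned as numpy arrays via
        # .np_ndarrays; --num-frames controls how many frames are extracted.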
        video = VideoAsset(name="sample_demo_1.mp4",
                           num_frames=args.num_frames).np_ndarrays
        vid_question = "Why is this video funny?"

        return {
            "data": video,
            "question": vid_question,
        }

    msg = f"Modality {args.modality} is not supported."
    raise ValueError(msg)


def main(args):
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    data = mm_input["data"]
    question = mm_input["question"]

    llm, prompt, stop_token_ids = model_example_map[model](question, modality)

    # We set temperature to 0.2 so that outputs can differ even when all
    # prompts are identical during batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)
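    # The model-specific stop_token_ids (possibly None) are passed through so
    # that generation also halts on those stop tokens.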

    assert args.num_prompts > 0
    if args.num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompt,
            "multi_modal_data": {
                modality: data
            },
        }

    else:
        # Batch inference
        inputs = [{
            "prompt": prompt,
            "multi_modal_data": {
                modality: data
            },
        } for _ in range(args.num_prompts)]

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="llava",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=4,
                        help='Number of prompts to run.')
    parser.add_argument('--modality',
                        type=str,
                        default="image",
                        choices=['image', 'video'],
                        help='Modality of the input.')
    parser.add_argument('--num-frames',
                        type=int,
                        default=16,
                        help='Number of frames to extract from the video.')
    args = parser.parse_args()
    main(args)