Offline Inference Vision Language Multi-Image
Source: vllm-project/vllm.
"""
This example shows how to use vLLM for running offline inference with
multi-image input on vision language models for text generation,
using the chat template defined by the model.
"""
6from argparse import Namespace
7from typing import List, NamedTuple, Optional
8
9from PIL.Image import Image
10from transformers import AutoProcessor, AutoTokenizer
11
12from vllm import LLM, SamplingParams
13from vllm.multimodal.utils import fetch_image
14from vllm.utils import FlexibleArgumentParser
15
# Question asked about every image, and the images it is asked about.
QUESTION = "What is the content of each image?"
IMAGE_URLS = [
    ("https://upload.wikimedia.org/wikipedia/commons/d/da/"
     "2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"),
    ("https://upload.wikimedia.org/wikipedia/commons/7/77/"
     "002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"),  # noqa: E501
]
21
22
class ModelRequestData(NamedTuple):
    """Everything one ``load_*`` helper prepares for a single example run.

    Each loader returns one of these; ``run_generate`` and ``run_chat``
    consume it.
    """

    # The vLLM engine configured for the selected model.
    llm: LLM
    # Fully rendered text prompt (including image placeholders), used by
    # ``run_generate``.
    prompt: str
    # Integer token IDs that stop generation (from ``convert_tokens_to_ids``,
    # ``eos_token_id`` or literal IDs), or None to pass no extra stop IDs.
    stop_token_ids: Optional[List[int]]
    # Images passed as ``multi_modal_data["image"]``.
    image_data: List[Image]
    # Explicit chat template for ``LLM.chat``; None when the model's own
    # template is used.
    chat_template: Optional[str]
29
30
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless otherwise noted, these settings have been tested on a single L4 GPU.
34
35
def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
    """Prepare a Qwen-VL-Chat request for ``question`` over ``image_urls``.

    The prompt is rendered with an explicit ChatML template. The template is
    also returned in ``chat_template`` so ``LLM.chat`` can reuse it. Stop
    IDs are looked up from the tokenizer.
    """
    checkpoint = "Qwen/Qwen-VL-Chat"
    engine = LLM(
        model=checkpoint,
        trust_remote_code=True,
        max_model_len=1024,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    # One "Picture N: <img></img>" line per image, numbered from 1.
    image_lines = []
    for idx in range(1, len(image_urls) + 1):
        image_lines.append(f"Picture {idx}: <img></img>\n")
    image_block = "".join(image_lines)

    # This model's tokenizer has no chat_template attribute, so one must be
    # passed explicitly. ChatML is used because the model's generation
    # utils use it:
    # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

    # ChatML template copied from
    # https://huggingface.co/docs/transformers/main/en/chat_templating
    chatml = (
        "{% if not add_generation_prompt is defined %}"
        "{% set add_generation_prompt = false %}"
        "{% endif %}"
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"  # noqa: E501
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{ '<|im_start|>assistant\n' }}"
        "{% endif %}")

    conversation = [{'role': 'user', 'content': f"{image_block}\n{question}"}]
    rendered = tok.apply_chat_template(conversation,
                                       tokenize=False,
                                       add_generation_prompt=True,
                                       chat_template=chatml)

    # Map each special token string to its ID in the tokenizer vocabulary.
    stop_ids = list(
        map(tok.convert_tokens_to_ids,
            ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]))

    images = [fetch_image(link) for link in image_urls]
    return ModelRequestData(
        llm=engine,
        prompt=rendered,
        stop_token_ids=stop_ids,
        image_data=images,
        chat_template=chatml,
    )
73
74
def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
    """Prepare a Phi-3.5-vision-instruct request with a hand-written prompt."""
    # ``num_crops`` overrides the multimodal image processor setting. For
    # models such as Phi-3.5-vision-instruct, 16 is recommended for
    # single-frame inputs and 4 for multi-frame ones. A larger value
    # generally gives more tokens per image, because preprocessing may
    # scale the image more. The model docs and the image-token formula are
    # referenced here:
    #
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
    engine = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"num_crops": 4},
    )

    # Numbered image tags <|image_1|>, <|image_2|>, ... one per line.
    image_tags = "\n".join(
        [f"<|image_{n}|>" for n in range(1, len(image_urls) + 1)])
    chat_prompt = f"<|user|>\n{image_tags}\n{question}<|end|>\n<|assistant|>\n"

    return ModelRequestData(
        llm=engine,
        prompt=chat_prompt,
        stop_token_ids=None,  # no extra stop token IDs for this model
        image_data=[fetch_image(u) for u in image_urls],
        chat_template=None,
    )
108
109
def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
    """Prepare an H2OVL-Mississippi-2B request.

    The prompt is rendered with the tokenizer's own chat template, and
    generation stops on the tokenizer's EOS token.
    """
    checkpoint = "h2oai/h2ovl-mississippi-2b"

    engine = LLM(
        model=checkpoint,
        trust_remote_code=True,
        max_model_len=8192,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"max_dynamic_patch": 4},
    )

    # Each entry already ends in "\n" and entries are joined with "\n", so
    # consecutive images are separated by a blank line.
    image_block = "\n".join(
        f"Image-{n}: <image>\n" for n in range(1, len(image_urls) + 1))
    conversation = [{'role': 'user', 'content': f"{image_block}\n{question}"}]

    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    rendered = tok.apply_chat_template(conversation,
                                       tokenize=False,
                                       add_generation_prompt=True)

    # Stop tokens for H2OVL-Mississippi; see
    # https://huggingface.co/h2oai/h2ovl-mississippi-2b
    return ModelRequestData(
        llm=engine,
        prompt=rendered,
        stop_token_ids=[tok.eos_token_id],
        image_data=[fetch_image(u) for u in image_urls],
        chat_template=None,
    )
142
143
def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
    """Prepare an InternVL2-2B request rendered with the model's chat template."""
    checkpoint = "OpenGVLab/InternVL2-2B"

    engine = LLM(
        model=checkpoint,
        trust_remote_code=True,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"max_dynamic_patch": 4},
    )

    # "Image-N: <image>\n" per image, joined by "\n" (blank line between).
    image_block = "\n".join(
        f"Image-{n}: <image>\n" for n in range(1, len(image_urls) + 1))
    conversation = [{'role': 'user', 'content': f"{image_block}\n{question}"}]

    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    rendered = tok.apply_chat_template(conversation,
                                       tokenize=False,
                                       add_generation_prompt=True)

    # InternVL variants may use different stop words; check the model card
    # for the correct set:
    # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
    stop_words = ("<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>")
    stop_ids = [tok.convert_tokens_to_ids(word) for word in stop_words]

    return ModelRequestData(
        llm=engine,
        prompt=rendered,
        stop_token_ids=stop_ids,
        image_data=[fetch_image(u) for u in image_urls],
        chat_template=None,
    )
179
180
def load_nvlm_d(question: str, image_urls: List[str]) -> ModelRequestData:
    """Prepare an NVLM-D-72B request rendered with the model's chat template.

    The engine is created with ``tensor_parallel_size=4``. No extra stop
    token IDs are supplied.

    Args:
        question: Text question appended after the image placeholders.
        image_urls: URLs of the images; one placeholder is emitted for each.

    Returns:
        The ``ModelRequestData`` consumed by ``run_generate`` / ``run_chat``.
    """
    model_name = "nvidia/NVLM-D-72B"

    # Adjust this as necessary to fit in GPU
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        tensor_parallel_size=4,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"max_dynamic_patch": 4},
    )

    # "Image-N: <image>\n" per image, joined by "\n" (blank line between).
    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    stop_token_ids = None

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )
212
213
def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
    """Prepare a Qwen2-VL-7B-Instruct request.

    If the optional ``qwen_vl_utils`` package can be imported, images come
    from its ``process_vision_info`` and ``max_model_len`` is 4096.
    Otherwise a warning is printed, images are fetched with ``fetch_image``,
    and ``max_model_len`` is raised to 32768.
    """
    try:
        from qwen_vl_utils import process_vision_info
    except ModuleNotFoundError:
        print('WARNING: `qwen-vl-utils` not installed, input images will not '
              'be automatically resized. You can enable this functionality by '
              '`pip install qwen-vl-utils`.')
        process_vision_info = None
    has_vision_utils = process_vision_info is not None

    checkpoint = "Qwen/Qwen2-VL-7B-Instruct"

    # Tested on L40
    engine = LLM(
        model=checkpoint,
        max_model_len=4096 if has_vision_utils else 32768,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    # System turn, then one user turn with every image before the question.
    system_turn = {"role": "system", "content": "You are a helpful assistant."}
    image_parts = [{"type": "image", "image": url} for url in image_urls]
    user_turn = {
        "role": "user",
        "content": image_parts + [{"type": "text", "text": question}],
    }
    messages = [system_turn, user_turn]

    processor = AutoProcessor.from_pretrained(checkpoint)
    rendered = processor.apply_chat_template(messages,
                                             tokenize=False,
                                             add_generation_prompt=True)

    if has_vision_utils:
        images, _ = process_vision_info(messages)
    else:
        images = [fetch_image(url) for url in image_urls]

    return ModelRequestData(
        llm=engine,
        prompt=rendered,
        stop_token_ids=None,
        image_data=images,
        chat_template=None,
    )
269
270
def load_mllama(question: str, image_urls: List[str]) -> ModelRequestData:
    """Prepare a Llama-3.2-11B-Vision-Instruct request.

    The prompt holds one ``<|image|>`` token per entry in ``image_urls``,
    followed by ``<|begin_of_text|>`` and the question.

    Args:
        question: Text question placed after the image tokens.
        image_urls: URLs of the images; one ``<|image|>`` token is emitted
            for each.

    Returns:
        The ``ModelRequestData`` consumed by ``run_generate`` / ``run_chat``.
    """
    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

    # The configuration below has been confirmed to launch on a single L40 GPU.
    llm = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=16,
        enforce_eager=True,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    # Emit exactly one image token per image. A hard-coded count would
    # disagree with limit_mm_per_prompt and image_data whenever the number
    # of URLs is not two.
    placeholders = "<|image|>" * len(image_urls)
    prompt = f"{placeholders}<|begin_of_text|>{question}"
    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=None,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )
291
292
def load_idefics3(question, image_urls: List[str]) -> ModelRequestData:
    """Prepare an Idefics3-8B-Llama3 request with a hand-written prompt."""
    checkpoint = "HuggingFaceM4/Idefics3-8B-Llama3"

    # If you run out of memory, reduce "longest_edge"; see
    # https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
    processor_overrides = {"size": {"longest_edge": 2 * 364}}

    # The configuration below has been confirmed to launch on a single L40 GPU.
    engine = LLM(
        model=checkpoint,
        max_model_len=8192,
        max_num_seqs=16,
        enforce_eager=True,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs=processor_overrides,
    )

    # "Image-N: <image>\n" per image, joined by "\n" (blank line between).
    image_block = "\n".join(
        f"Image-{n}: <image>\n" for n in range(1, len(image_urls) + 1))
    chat_prompt = ("<|begin_of_text|>User:"
                   f"{image_block}\n{question}<end_of_utterance>\nAssistant:")

    return ModelRequestData(
        llm=engine,
        prompt=chat_prompt,
        stop_token_ids=None,
        image_data=[fetch_image(u) for u in image_urls],
        chat_template=None,
    )
322
323
def load_aria(question, image_urls: List[str]) -> ModelRequestData:
    """Prepare a rhymes-ai/Aria request with a hand-written im_start/im_end prompt."""
    engine = LLM(
        model="rhymes-ai/Aria",
        tokenizer_mode="slow",
        trust_remote_code=True,
        dtype="bfloat16",
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    # One "<fim_prefix><|img|><fim_suffix>" line per image.
    image_block = "".join(
        "<fim_prefix><|img|><fim_suffix>\n" for _ in image_urls)
    chat_prompt = (f"<|im_start|>user\n{image_block}{question}<|im_end|>\n"
                   "<|im_start|>assistant\n")

    # Hard-coded stop token IDs for this model (93653 is listed twice).
    stop_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]

    return ModelRequestData(
        llm=engine,
        prompt=chat_prompt,
        stop_token_ids=stop_ids,
        image_data=[fetch_image(u) for u in image_urls],
        chat_template=None,
    )
341
342
# Maps each ``--model-type`` choice to the loader that builds its request.
# Insertion order is kept, so the CLI lists the choices in this order.
model_example_map = dict(
    phi3_v=load_phi3v,
    h2ovl_chat=load_h2onvl,
    internvl_chat=load_internvl,
    NVLM_D=load_nvlm_d,
    qwen2_vl=load_qwen2_vl,
    qwen_vl_chat=load_qwenvl_chat,
    mllama=load_mllama,
    idefics3=load_idefics3,
    aria=load_aria,
)
354
355
def run_generate(model, question: str, image_urls: List[str]):
    """Run ``LLM.generate`` on the prompt prepared for ``model`` and print results.

    Decoding is greedy (temperature 0.0) and capped at 128 new tokens. The
    first completion of each output is printed.
    """
    loader = model_example_map[model]
    request = loader(question, image_urls)

    params = SamplingParams(temperature=0.0,
                            max_tokens=128,
                            stop_token_ids=request.stop_token_ids)

    # Single request: the rendered prompt plus its images.
    engine_input = {
        "prompt": request.prompt,
        "multi_modal_data": {"image": request.image_data},
    }
    results = request.llm.generate(engine_input, sampling_params=params)

    for result in results:
        print(result.outputs[0].text)
375
376
def run_chat(model: str, question: str, image_urls: List[str]):
    """Run ``LLM.chat`` with one user turn holding the question and image URLs.

    The request's ``chat_template`` (possibly None) is forwarded to
    ``LLM.chat``. The first completion of each output is printed.
    """
    request = model_example_map[model](question, image_urls)

    params = SamplingParams(temperature=0.0,
                            max_tokens=128,
                            stop_token_ids=request.stop_token_ids)

    # The text part comes first, followed by one image_url part per URL.
    content = [{"type": "text", "text": question}]
    content.extend({
        "type": "image_url",
        "image_url": {
            "url": link
        },
    } for link in image_urls)
    conversation = [{"role": "user", "content": content}]

    results = request.llm.chat(
        conversation,
        sampling_params=params,
        chat_template=request.chat_template,
    )

    for result in results:
        print(result.outputs[0].text)
407
408
def main(args: Namespace):
    """Run the example selected by ``args.model_type`` using ``args.method``.

    Raises:
        ValueError: if ``args.method`` is neither "generate" nor "chat".
    """
    model = args.model_type
    method = args.method

    if method not in ("generate", "chat"):
        raise ValueError(f"Invalid method: {method}")

    runner = run_generate if method == "generate" else run_chat
    runner(model, QUESTION, IMAGE_URLS)
419
420
if __name__ == "__main__":
    # Command-line entry point: pick a model and a vllm.LLM method.
    parser = FlexibleArgumentParser(
        description=('Demo on using vLLM for offline inference with '
                     'vision language models that support multi-image input '
                     'for text generation'))
    parser.add_argument(
        '--model-type',
        '-m',
        type=str,
        default="phi3_v",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        "--method",
        type=str,
        default="generate",
        choices=["generate", "chat"],
        help="The method to run in `vllm.LLM`.",
    )

    args = parser.parse_args()
    main(args)