Offline Inference Vision Language Multi Image

Offline Inference Vision Language Multi Image

Source: examples/offline_inference_vision_language_multi_image.py.

  1"""
  2This example shows how to use vLLM for running offline inference with
  3multi-image input on vision language models for text generation,
  4using the chat template defined by the model.
  5"""
  6from argparse import Namespace
  7from typing import List, NamedTuple, Optional
  8
  9from PIL.Image import Image
 10from transformers import AutoProcessor, AutoTokenizer
 11
 12from vllm import LLM, SamplingParams
 13from vllm.multimodal.utils import fetch_image
 14from vllm.utils import FlexibleArgumentParser
 15
# Question asked about the images; `main` passes it to both run methods.
QUESTION = "What is the content of each image?"
# Sample image URLs (hosted on Wikimedia Commons) used as the multi-image
# input; `main` passes this list to both run methods.
IMAGE_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
]
 21
 22
 23class ModelRequestData(NamedTuple):
 24    llm: LLM
 25    prompt: str
 26    stop_token_ids: Optional[List[str]]
 27    image_data: List[Image]
 28    chat_template: Optional[str]
 29
 30
 31# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
 32# lower-end GPUs.
 33# Unless specified, these settings have been tested to work on a single L4.
 34
 35
 36def load_aria(question, image_urls: List[str]) -> ModelRequestData:
 37    model_name = "rhymes-ai/Aria"
 38    llm = LLM(model=model_name,
 39              tokenizer_mode="slow",
 40              trust_remote_code=True,
 41              dtype="bfloat16",
 42              limit_mm_per_prompt={"image": len(image_urls)})
 43    placeholders = "<fim_prefix><|img|><fim_suffix>\n" * len(image_urls)
 44    prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n"
 45              "<|im_start|>assistant\n")
 46    stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
 47    return ModelRequestData(
 48        llm=llm,
 49        prompt=prompt,
 50        stop_token_ids=stop_token_ids,
 51        image_data=[fetch_image(url) for url in image_urls],
 52        chat_template=None)
 53
 54
 55def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
 56    model_name = "h2oai/h2ovl-mississippi-2b"
 57
 58    llm = LLM(
 59        model=model_name,
 60        trust_remote_code=True,
 61        max_model_len=8192,
 62        limit_mm_per_prompt={"image": len(image_urls)},
 63        mm_processor_kwargs={"max_dynamic_patch": 4},
 64    )
 65
 66    placeholders = "\n".join(f"Image-{i}: <image>\n"
 67                             for i, _ in enumerate(image_urls, start=1))
 68    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
 69
 70    tokenizer = AutoTokenizer.from_pretrained(model_name,
 71                                              trust_remote_code=True)
 72    prompt = tokenizer.apply_chat_template(messages,
 73                                           tokenize=False,
 74                                           add_generation_prompt=True)
 75
 76    # Stop tokens for H2OVL-Mississippi
 77    # https://huggingface.co/h2oai/h2ovl-mississippi-2b
 78    stop_token_ids = [tokenizer.eos_token_id]
 79
 80    return ModelRequestData(
 81        llm=llm,
 82        prompt=prompt,
 83        stop_token_ids=stop_token_ids,
 84        image_data=[fetch_image(url) for url in image_urls],
 85        chat_template=None,
 86    )
 87
 88
 89def load_idefics3(question, image_urls: List[str]) -> ModelRequestData:
 90    model_name = "HuggingFaceM4/Idefics3-8B-Llama3"
 91
 92    # The configuration below has been confirmed to launch on a single L40 GPU.
 93    llm = LLM(
 94        model=model_name,
 95        max_model_len=8192,
 96        max_num_seqs=16,
 97        enforce_eager=True,
 98        limit_mm_per_prompt={"image": len(image_urls)},
 99        # if you are running out of memory, you can reduce the "longest_edge".
100        # see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
101        mm_processor_kwargs={
102            "size": {
103                "longest_edge": 2 * 364
104            },
105        },
106    )
107
108    placeholders = "\n".join(f"Image-{i}: <image>\n"
109                             for i, _ in enumerate(image_urls, start=1))
110    prompt = f"<|begin_of_text|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:"  # noqa: E501
111    return ModelRequestData(
112        llm=llm,
113        prompt=prompt,
114        stop_token_ids=None,
115        image_data=[fetch_image(url) for url in image_urls],
116        chat_template=None,
117    )
118
119
def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
    """Build the engine, prompt and image inputs for InternVL2-2B.

    Args:
        question: Text question appended after the image tags.
        image_urls: URLs of the input images; one numbered ``<image>`` tag
            is emitted per URL and each URL is fetched with `fetch_image`.

    Returns:
        The engine, the prompt rendered through the tokenizer's chat
        template, the ids of the stop tokens and the fetched images.
        No chat template is set.
    """
    model_name = "OpenGVLab/InternVL2-2B"

    engine = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"max_dynamic_patch": 4},
    )

    # Numbered image tags, counted from 1, joined with a newline.
    image_tags = [
        f"Image-{idx}: <image>\n" for idx in range(1, len(image_urls) + 1)
    ]
    user_text = "\n".join(image_tags) + f"\n{question}"
    conversation = [{'role': 'user', 'content': user_text}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    rendered_prompt = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Stop tokens for InternVL
    # models variants may have different stop tokens
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
    stop_ids = []
    for token in ("<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"):
        stop_ids.append(tokenizer.convert_tokens_to_ids(token))

    images = [fetch_image(link) for link in image_urls]
    return ModelRequestData(
        llm=engine,
        prompt=rendered_prompt,
        stop_token_ids=stop_ids,
        image_data=images,
        chat_template=None,
    )
155
156
def load_mllama(question: str, image_urls: List[str]) -> ModelRequestData:
    """Build the engine, prompt and image inputs for Llama-3.2-11B-Vision.

    Args:
        question: Text question placed after ``<|begin_of_text|>``.
        image_urls: URLs of the input images; one ``<|image|>`` placeholder
            is emitted per URL and each URL is fetched with `fetch_image`.

    Returns:
        The engine, the hand-formatted prompt and the fetched images.
        Neither stop token ids nor a chat template are set.
    """
    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

    # The configuration below has been confirmed to launch on a single L40 GPU.
    llm = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=16,
        enforce_eager=True,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    # Emit exactly one placeholder per image so the prompt always matches
    # the number of images supplied, rather than assuming two.
    placeholders = "<|image|>" * len(image_urls)
    prompt = f"{placeholders}<|begin_of_text|>{question}"
    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=None,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )
177
178
def load_nvlm_d(question: str, image_urls: List[str]) -> ModelRequestData:
    """Build the engine, prompt and image inputs for NVLM-D-72B.

    Args:
        question: Text question appended after the image tags.
        image_urls: URLs of the input images; one numbered ``<image>`` tag
            is emitted per URL and each URL is fetched with `fetch_image`.

    Returns:
        The engine, the prompt rendered through the tokenizer's chat
        template and the fetched images. Neither stop token ids nor a chat
        template are set.
    """
    model_name = "nvidia/NVLM-D-72B"

    # Adjust this as necessary to fit in GPU
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        tensor_parallel_size=4,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"max_dynamic_patch": 4},
    )

    # Numbered image tags, counted from 1, joined with a newline.
    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    stop_token_ids = None

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )
210
211
def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
    """Build the engine, prompt and image inputs for Phi-3.5-vision-instruct.

    Args:
        question: Text question placed after the image tags.
        image_urls: URLs of the input images; one ``<|image_N|>`` tag
            (N counted from 1) is emitted per URL and each URL is fetched
            with `fetch_image`.

    Returns:
        The engine, the hand-formatted prompt and the fetched images.
        Neither stop token ids nor a chat template are set.
    """
    # num_crops is an override kwarg to the multimodal image processor;
    # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
    # to use 16 for single frame scenarios, and 4 for multi-frame.
    #
    # Generally speaking, a larger value for num_crops results in more
    # tokens per image instance, because it may scale the image more in
    # the image preprocessing. Some references in the model docs and the
    # formula for image tokens after the preprocessing
    # transform can be found below.
    #
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
    engine = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"num_crops": 4},
    )

    # One tag per image, each on its own line.
    image_tags = [f"<|image_{idx}|>" for idx in range(1, len(image_urls) + 1)]
    joined_tags = "\n".join(image_tags)
    full_prompt = f"<|user|>\n{joined_tags}\n{question}<|end|>\n<|assistant|>\n"

    images = [fetch_image(link) for link in image_urls]
    return ModelRequestData(
        llm=engine,
        prompt=full_prompt,
        stop_token_ids=None,
        image_data=images,
        chat_template=None,
    )
245
246
def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
    """Build the engine, prompt and image inputs for Qwen-VL-Chat.

    Args:
        question: Text question appended after the picture tags.
        image_urls: URLs of the input images; one numbered
            ``<img></img>`` tag is emitted per URL and each URL is fetched
            with `fetch_image`.

    Returns:
        The engine, the prompt rendered with an explicit ChatML template,
        the ids of the stop tokens, the fetched images, and that same
        template for use by the chat method.
    """
    model_name = "Qwen/Qwen-VL-Chat"
    engine = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=1024,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": len(image_urls)},
    )
    picture_tags = "".join(
        [f"Picture {idx}: <img></img>\n"
         for idx in range(1, len(image_urls) + 1)])

    # This model does not have a chat_template attribute on its tokenizer,
    # so we need to explicitly pass it. We use ChatML since it's used in the
    # generation utils of the model:
    # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)

    # Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
    # Split across adjacent literals for readability; the concatenated
    # template text is unchanged.
    chatml_template = (
        "{% if not add_generation_prompt is defined %}"
        "{% set add_generation_prompt = false %}"
        "{% endif %}"
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content']"
        " + '<|im_end|>' + '\n'}}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{ '<|im_start|>assistant\n' }}"
        "{% endif %}"
    )

    conversation = [{
        'role': 'user',
        'content': f"{picture_tags}\n{question}",
    }]
    rendered_prompt = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
        chat_template=chatml_template,
    )

    stop_ids = []
    for token in ("<|endoftext|>", "<|im_start|>", "<|im_end|>"):
        stop_ids.append(tokenizer.convert_tokens_to_ids(token))

    images = [fetch_image(link) for link in image_urls]
    return ModelRequestData(
        llm=engine,
        prompt=rendered_prompt,
        stop_token_ids=stop_ids,
        image_data=images,
        chat_template=chatml_template,
    )
284
285
def load_qwen2_vl(question: str, image_urls: List[str]) -> ModelRequestData:
    """Build the engine, prompt and image inputs for Qwen2-VL-7B-Instruct.

    If the optional `qwen_vl_utils` package is importable, its
    `process_vision_info` prepares the images from the chat messages;
    otherwise a warning is printed, the images are fetched as-is with
    `fetch_image`, and a larger `max_model_len` is configured.

    Args:
        question: Text question placed after the image entries.
        image_urls: URLs of the input images; one image entry is added to
            the user message per URL.

    Returns:
        The engine, the prompt rendered through the processor's chat
        template and the image data. Neither stop token ids nor a chat
        template are set.
    """
    try:
        from qwen_vl_utils import process_vision_info
    except ModuleNotFoundError:
        print('WARNING: `qwen-vl-utils` not installed, input images will not '
              'be automatically resized. You can enable this functionality by '
              '`pip install qwen-vl-utils`.')
        process_vision_info = None

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    # Tested on L40
    llm = LLM(
        model=model_name,
        # Without the helper the images are not resized, so a much larger
        # context length is configured for that case.
        max_model_len=32768 if process_vision_info is None else 4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [{
        "role": "system",
        "content": "You are a helpful assistant."
    }, {
        "role":
        "user",
        "content": [
            *placeholders,
            {
                "type": "text",
                "text": question
            },
        ],
    }]

    processor = AutoProcessor.from_pretrained(model_name)

    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    stop_token_ids = None

    if process_vision_info is None:
        image_data = [fetch_image(url) for url in image_urls]
    else:
        # Only the first element of the helper's result is used here.
        image_data, _ = process_vision_info(messages)

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=image_data,
        chat_template=None,
    )
341
342
# Maps each accepted `--model-type` value to the loader function that builds
# the engine, prompt and image inputs for that model. The run functions look
# the loader up by key and call it with (question, image_urls).
model_example_map = {
    "aria": load_aria,
    "h2ovl_chat": load_h2onvl,
    "idefics3": load_idefics3,
    "internvl_chat": load_internvl,
    "mllama": load_mllama,
    "NVLM_D": load_nvlm_d,
    "phi3_v": load_phi3v,
    "qwen_vl_chat": load_qwenvl_chat,
    "qwen2_vl": load_qwen2_vl,
}
354
355
def run_generate(model: str, question: str, image_urls: List[str]):
    """Run the example through `LLM.generate` and print the completions.

    The prompt and images come from the loader registered for `model`, so
    the model-specific prompt format is used as-is.

    Args:
        model: Key into `model_example_map` selecting the loader.
        question: Text question forwarded to the loader.
        image_urls: Image URLs forwarded to the loader.

    Raises:
        KeyError: If `model` is not a key of `model_example_map`.
    """
    req_data = model_example_map[model](question, image_urls)

    # Temperature 0.0 and at most 128 new tokens, stopping early on any of
    # the loader's stop token ids (if it provided some).
    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=req_data.stop_token_ids)

    outputs = req_data.llm.generate(
        {
            "prompt": req_data.prompt,
            "multi_modal_data": {
                "image": req_data.image_data
            },
        },
        sampling_params=sampling_params)

    for o in outputs:
        # Print the text of the first completion of each request output.
        generated_text = o.outputs[0].text
        print(generated_text)
375
376
def run_chat(model: str, question: str, image_urls: List[str]):
    """Run the example through `LLM.chat` and print the completions.

    Only the engine, stop token ids and chat template from the loader are
    used here; the message is built from `question` and `image_urls`
    directly, with each image passed by URL.

    Args:
        model: Key into `model_example_map` selecting the loader.
        question: Text part of the single user message.
        image_urls: URLs attached to the message as `image_url` parts.
    """
    loader = model_example_map[model]
    req_data = loader(question, image_urls)

    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=128,
        stop_token_ids=req_data.stop_token_ids,
    )

    # The text part comes first, followed by one part per image URL.
    text_part = {"type": "text", "text": question}
    image_parts = [{
        "type": "image_url",
        "image_url": {
            "url": link
        },
    } for link in image_urls]
    user_message = {"role": "user", "content": [text_part, *image_parts]}

    results = req_data.llm.chat(
        [user_message],
        sampling_params=sampling_params,
        chat_template=req_data.chat_template,
    )

    for result in results:
        print(result.outputs[0].text)
407
408
def main(args: Namespace):
    """Dispatch to the run function selected on the command line.

    Args:
        args: Parsed arguments; `model_type` selects the model and `method`
            selects between the "generate" and "chat" run functions.

    Raises:
        ValueError: If `args.method` is neither "generate" nor "chat".
    """
    model, method = args.model_type, args.method

    # Reject unknown methods up front, then dispatch.
    if method not in ("generate", "chat"):
        raise ValueError(f"Invalid method: {method}")

    if method == "generate":
        run_generate(model, QUESTION, IMAGE_URLS)
    else:
        run_chat(model, QUESTION, IMAGE_URLS)
419
420
if __name__ == "__main__":
    # Command-line entry point: parse the model type and run method, then
    # hand over to `main`.
    description = ('Demo on using vLLM for offline inference with '
                   'vision language models that support multi-image input '
                   'for text generation')
    parser = FlexibleArgumentParser(description=description)

    # Which model loader to use; restricted to the registered loader keys.
    parser.add_argument(
        '--model-type',
        '-m',
        type=str,
        default="phi3_v",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    # Which `vllm.LLM` entry point to exercise.
    parser.add_argument(
        "--method",
        type=str,
        default="generate",
        choices=["generate", "chat"],
        help="The method to run in `vllm.LLM`.",
    )

    args = parser.parse_args()
    main(args)
434                        default="generate",
435                        choices=["generate", "chat"],
436                        help="The method to run in `vllm.LLM`.")
437
438    args = parser.parse_args()
439    main(args)