Offline Inference Vision Language Multi Image

Source: vllm-project/vllm.
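
A quick usage sketch before the full listing (the script path below is an assumption based on the upstream examples directory and may differ in your checkout):

    python examples/offline_inference_vision_language_multi_image.py \
        --model-type phi3_v --method generate

The optional `qwen-vl-utils` package (`pip install qwen-vl-utils`) is only used by the `qwen2_vl` loader to automatically resize input images; the other loaders work without it.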

  1"""
  2This example shows how to use vLLM for running offline inference with
  3multi-image input on vision language models for text generation,
  4using the chat template defined by the model.
  5"""
  6from argparse import Namespace
  7from typing import List, NamedTuple, Optional
  8
  9from PIL.Image import Image
 10from transformers import AutoProcessor, AutoTokenizer
 11
 12from vllm import LLM, SamplingParams
 13from vllm.multimodal.utils import fetch_image
 14from vllm.utils import FlexibleArgumentParser
 15
 16QUESTION = "What is the content of each image?"
 17IMAGE_URLS = [
 18    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
 19    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
 20]
 21
 22
class ModelRequestData(NamedTuple):
    llm: LLM
    prompt: str
    stop_token_ids: Optional[List[int]]
    image_data: List[Image]
    chat_template: Optional[str]


# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.


def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "Qwen/Qwen-VL-Chat"
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=1024,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": len(image_urls)},
    )
    placeholders = "".join(f"Picture {i}: <img></img>\n"
                           for i, _ in enumerate(image_urls, start=1))

    # This model does not have a chat_template attribute on its tokenizer,
    # so we need to explicitly pass it. We use ChatML since it's used in the
    # generation utils of the model:
    # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)

    # Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
    chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"  # noqa: E501

    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True,
                                           chat_template=chat_template)

    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=chat_template,
    )


def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
    # num_crops is an override kwarg to the multimodal image processor;
    # for some models, e.g., Phi-3.5-vision-instruct, it is recommended
    # to use 16 for single-frame scenarios and 4 for multi-frame.
    #
    # Generally speaking, a larger num_crops results in more tokens per
    # image instance, because the image may be scaled up more during
    # preprocessing. References in the model docs and the formula for the
    # number of image tokens after the preprocessing transform can be
    # found below.
    #
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"num_crops": 4},
    )
    placeholders = "\n".join(f"<|image_{i}|>"
                             for i, _ in enumerate(image_urls, start=1))
    prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
    stop_token_ids = None

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "h2oai/h2ovl-mississippi-2b"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"max_dynamic_patch": 4},
    )

    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for H2OVL-Mississippi
    # https://huggingface.co/h2oai/h2ovl-mississippi-2b
    stop_token_ids = [tokenizer.eos_token_id]

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "OpenGVLab/InternVL2-2B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"max_dynamic_patch": 4},
    )

    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL.
    # Model variants may have different stop tokens;
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_nvlm_d(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "nvidia/NVLM-D-72B"

    # Adjust this as necessary to fit in GPU memory.
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        tensor_parallel_size=4,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"max_dynamic_patch": 4},
    )

    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    stop_token_ids = None

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_qwen2_vl(question: str, image_urls: List[str]) -> ModelRequestData:
    try:
        from qwen_vl_utils import process_vision_info
    except ModuleNotFoundError:
        print('WARNING: `qwen-vl-utils` not installed, input images will not '
              'be automatically resized. You can enable this functionality '
              'with `pip install qwen-vl-utils`.')
        process_vision_info = None

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    # Tested on L40
    llm = LLM(
        model=model_name,
        max_model_len=32768 if process_vision_info is None else 4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [{
        "role": "system",
        "content": "You are a helpful assistant."
    }, {
        "role":
        "user",
        "content": [
            *placeholders,
            {
                "type": "text",
                "text": question
            },
        ],
    }]

    processor = AutoProcessor.from_pretrained(model_name)

    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    stop_token_ids = None

    if process_vision_info is None:
        image_data = [fetch_image(url) for url in image_urls]
    else:
        image_data, _ = process_vision_info(messages)

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=image_data,
        chat_template=None,
    )


def load_mllama(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

    # The configuration below has been confirmed to launch on a single L40 GPU.
    llm = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=16,
        enforce_eager=True,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    # One <|image|> placeholder per input image.
    prompt = "<|image|>" * len(image_urls) + f"<|begin_of_text|>{question}"
    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=None,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_idefics3(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "HuggingFaceM4/Idefics3-8B-Llama3"

    # The configuration below has been confirmed to launch on a single L40 GPU.
    llm = LLM(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=16,
        enforce_eager=True,
        limit_mm_per_prompt={"image": len(image_urls)},
        # If you are running out of memory, you can reduce the "longest_edge".
        # See: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
        mm_processor_kwargs={
            "size": {
                "longest_edge": 2 * 364
            },
        },
    )

    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    prompt = f"<|begin_of_text|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:"  # noqa: E501
    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=None,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_aria(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "rhymes-ai/Aria"
    llm = LLM(model=model_name,
              tokenizer_mode="slow",
              trust_remote_code=True,
              dtype="bfloat16",
              limit_mm_per_prompt={"image": len(image_urls)})
    placeholders = "<fim_prefix><|img|><fim_suffix>\n" * len(image_urls)
    prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
    stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None)


model_example_map = {
    "phi3_v": load_phi3v,
    "h2ovl_chat": load_h2onvl,
    "internvl_chat": load_internvl,
    "NVLM_D": load_nvlm_d,
    "qwen2_vl": load_qwen2_vl,
    "qwen_vl_chat": load_qwenvl_chat,
    "mllama": load_mllama,
    "idefics3": load_idefics3,
    "aria": load_aria,
}


def run_generate(model: str, question: str, image_urls: List[str]):
    req_data = model_example_map[model](question, image_urls)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=req_data.stop_token_ids)

    outputs = req_data.llm.generate(
        {
            "prompt": req_data.prompt,
            "multi_modal_data": {
                "image": req_data.image_data
            },
        },
        sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def run_chat(model: str, question: str, image_urls: List[str]):
    req_data = model_example_map[model](question, image_urls)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=req_data.stop_token_ids)
    outputs = req_data.llm.chat(
        [{
            "role":
            "user",
            "content": [
                {
                    "type": "text",
                    "text": question,
                },
                *({
                    "type": "image_url",
                    "image_url": {
                        "url": image_url
                    },
                } for image_url in image_urls),
            ],
        }],
        sampling_params=sampling_params,
        chat_template=req_data.chat_template,
    )

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def main(args: Namespace):
    model = args.model_type
    method = args.method

    if method == "generate":
        run_generate(model, QUESTION, IMAGE_URLS)
    elif method == "chat":
        run_chat(model, QUESTION, IMAGE_URLS)
    else:
        raise ValueError(f"Invalid method: {method}")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models that support multi-image input for text '
        'generation')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="phi3_v",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument("--method",
                        type=str,
                        default="generate",
                        choices=["generate", "chat"],
                        help="The method to run in `vllm.LLM`.")

    args = parser.parse_args()
    main(args)