Offline Inference Vision Language

Offline Inference Vision Language

Source: examples/offline_inference_vision_language.py.

  1"""
  2This example shows how to use vLLM for running offline inference with
  3the correct prompt format on vision language models for text generation.
  4
  5For most models, the prompt format should follow corresponding examples
  6on HuggingFace model repository.
  7"""
  8import random
  9
 10from transformers import AutoTokenizer
 11
 12from vllm import LLM, SamplingParams
 13from vllm.assets.image import ImageAsset
 14from vllm.assets.video import VideoAsset
 15from vllm.utils import FlexibleArgumentParser
 16
 17# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
 18# lower-end GPUs.
 19# Unless specified, these settings have been tested to work on a single L4.
 20
 21
 22# Aria
 23def run_aria(question: str, modality: str):
 24    assert modality == "image"
 25    model_name = "rhymes-ai/Aria"
 26
 27    llm = LLM(model=model_name,
 28              tokenizer_mode="slow",
 29              trust_remote_code=True,
 30              dtype="bfloat16",
 31              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 32
 33    prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
 34              "<|im_end|>\n<|im_start|>assistant\n")
 35
 36    stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
 37    return llm, prompt, stop_token_ids
 38
 39
 40# BLIP-2
 41def run_blip2(question: str, modality: str):
 42    assert modality == "image"
 43
 44    # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
 45    # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
 46    prompt = f"Question: {question} Answer:"
 47    llm = LLM(model="Salesforce/blip2-opt-2.7b",
 48              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 49    stop_token_ids = None
 50    return llm, prompt, stop_token_ids
 51
 52
 53# Chameleon
 54def run_chameleon(question: str, modality: str):
 55    assert modality == "image"
 56
 57    prompt = f"{question}<image>"
 58    llm = LLM(model="facebook/chameleon-7b",
 59              max_model_len=4096,
 60              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 61    stop_token_ids = None
 62    return llm, prompt, stop_token_ids
 63
 64
 65# Fuyu
 66def run_fuyu(question: str, modality: str):
 67    assert modality == "image"
 68
 69    prompt = f"{question}\n"
 70    llm = LLM(model="adept/fuyu-8b",
 71              max_model_len=2048,
 72              max_num_seqs=2,
 73              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 74    stop_token_ids = None
 75    return llm, prompt, stop_token_ids
 76
 77
 78# GLM-4v
 79def run_glm4v(question: str, modality: str):
 80    assert modality == "image"
 81    model_name = "THUDM/glm-4v-9b"
 82
 83    llm = LLM(model=model_name,
 84              max_model_len=2048,
 85              max_num_seqs=2,
 86              trust_remote_code=True,
 87              enforce_eager=True,
 88              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 89    prompt = question
 90    stop_token_ids = [151329, 151336, 151338]
 91    return llm, prompt, stop_token_ids
 92
 93
 94# H2OVL-Mississippi
 95def run_h2ovl(question: str, modality: str):
 96    assert modality == "image"
 97
 98    model_name = "h2oai/h2ovl-mississippi-2b"
 99
100    llm = LLM(
101        model=model_name,
102        trust_remote_code=True,
103        max_model_len=8192,
104        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
105    )
106
107    tokenizer = AutoTokenizer.from_pretrained(model_name,
108                                              trust_remote_code=True)
109    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
110    prompt = tokenizer.apply_chat_template(messages,
111                                           tokenize=False,
112                                           add_generation_prompt=True)
113
114    # Stop tokens for H2OVL-Mississippi
115    # https://huggingface.co/h2oai/h2ovl-mississippi-2b
116    stop_token_ids = [tokenizer.eos_token_id]
117    return llm, prompt, stop_token_ids
118
119
120# Idefics3-8B-Llama3
def run_idefics3(question: str, modality: str):
    """Create the engine and prompt for HuggingFaceM4/Idefics3-8B-Llama3.

    Args:
        question: Text placed after the ``<image>`` placeholder.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; no stop tokens are set.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    assert modality == "image"

    # If you are running out of memory, you can reduce "longest_edge".
    # See: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
    processor_kwargs = {"size": {"longest_edge": 3 * 364}}

    engine = LLM(
        model="HuggingFaceM4/Idefics3-8B-Llama3",
        max_model_len=8192,
        max_num_seqs=2,
        enforce_eager=True,
        mm_processor_kwargs=processor_kwargs,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    text = f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
    return engine, text, None
144
145
146# InternVL
def run_internvl(question: str, modality: str):
    """Create the engine, prompt and stop tokens for OpenGVLab/InternVL2-2B.

    The prompt is rendered with the tokenizer's own chat template.

    Args:
        question: Text placed after the ``<image>`` placeholder.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    assert modality == "image"

    checkpoint = "OpenGVLab/InternVL2-2B"
    engine = LLM(
        model=checkpoint,
        trust_remote_code=True,
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    chat = [{'role': 'user', 'content': f"<image>\n{question}"}]
    rendered = tok.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Model variants may use different stop tokens; refer to the model card
    # for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
    stop_ids = []
    for token in ("<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"):
        stop_ids.append(tok.convert_tokens_to_ids(token))
    return engine, rendered, stop_ids
173
174
175# LLaVA-1.5
def run_llava(question: str, modality: str):
    """Create the engine and prompt for llava-hf/llava-1.5-7b-hf.

    Args:
        question: Text placed after the ``<image>`` placeholder.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; no stop tokens are set.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    assert modality == "image"

    engine = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, f"USER: <image>\n{question}\nASSISTANT:", None
186
187
188# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(question: str, modality: str):
    """Create the engine and prompt for llava-hf/llava-v1.6-mistral-7b-hf.

    Args:
        question: Text placed after the ``<image>`` placeholder.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; no stop tokens are set.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    assert modality == "image"

    engine_kwargs = {
        "model": "llava-hf/llava-v1.6-mistral-7b-hf",
        "max_model_len": 8192,
        "disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache,
    }
    return LLM(**engine_kwargs), f"[INST] <image>\n{question} [/INST]", None
198
199
200# LlaVA-NeXT-Video
201# Currently only support for video input
def run_llava_next_video(question: str, modality: str):
    """Create the engine and prompt for llava-hf/LLaVA-NeXT-Video-7B-hf.

    Args:
        question: Text placed after the ``<video>`` placeholder.
        modality: Must be ``"video"``; this example accepts video input only.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; no stop tokens are set.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    assert modality == "video"

    engine = LLM(
        model="llava-hf/LLaVA-NeXT-Video-7B-hf",
        max_model_len=8192,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, f"USER: <video>\n{question} ASSISTANT:", None
211
212
213# LLaVA-OneVision
def run_llava_onevision(question: str, modality: str):
    """Create the engine and prompt for LLaVA-OneVision (Qwen2 7B).

    Args:
        question: Text placed after the modality placeholder.
        modality: ``"image"`` or ``"video"``; selects the ``<image>`` or
            ``<video>`` placeholder in the prompt.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; no stop tokens are set.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    # Reject unsupported modalities up front. Without this check `prompt`
    # would be unbound and the failure would only appear after the model
    # had already been loaded.
    if modality not in ("image", "video"):
        raise ValueError(f"Modality {modality} is not supported.")

    # NOTE(review): this prompt used to be written with a backslash line
    # continuation inside the f-string, which embedded the next source line's
    # indentation in the prompt. The resulting nine spaces between
    # "<|im_end|>" and "<|im_start|>" are kept byte-for-byte here so the
    # example's behavior is unchanged - confirm against the model's chat
    # template whether they should be removed.
    prompt = (f"<|im_start|>user <{modality}>\n{question}<|im_end|> "
              "        <|im_start|>assistant\n")

    llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
              max_model_len=16384,
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    stop_token_ids = None
    return llm, prompt, stop_token_ids
229
230
231# Mantis
def run_mantis(question: str, modality: str):
    """Create the engine, prompt and stop tokens for Mantis-8B-siglip-llama3.

    Args:
        question: Text placed before the ``<image>`` placeholder in the
            user turn.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    assert modality == "image"

    # Llama-3 style chat layout: one user turn, then an open assistant turn.
    user_content = f"{question}\n<image>"
    prompt = ('<|start_header_id|>user<|end_header_id|>\n\n'
              + user_content
              + '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n')

    engine = LLM(
        model="TIGER-Lab/Mantis-8B-siglip-llama3",
        max_model_len=4096,
        hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    return engine, prompt, [128009]
246
247
248# MiniCPM-V
def run_minicpmv(question: str, modality: str):
    """Create the engine, prompt and stop tokens for MiniCPM-V 2.6.

    The prompt is rendered with the tokenizer's own chat template.

    Args:
        question: Text placed after the image placeholder.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    assert modality == "image"

    # Checkpoints for other MiniCPM-V versions:
    #   2.0: "HwwwH/MiniCPM-V-2" - the official repo doesn't work yet, so a
    #        fork is needed for now. For more details, see:
    #        https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 # noqa
    #   2.5: "openbmb/MiniCPM-Llama3-V-2_5"
    #   2.6: "openbmb/MiniCPM-V-2_6" (used here)
    checkpoint = "openbmb/MiniCPM-V-2_6"

    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    engine = LLM(
        model=checkpoint,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    # NOTE: The stop_token_ids differ between versions of MiniCPM-V:
    #   2.0: [tokenizer.eos_id]
    #   2.5: [tokenizer.eos_id, tokenizer.eot_id]
    #   2.6: the two tokens converted below
    stop_ids = []
    for token in ('<|im_end|>', '<|endoftext|>'):
        stop_ids.append(tok.convert_tokens_to_ids(token))

    chat = [{'role': 'user', 'content': f'(<image>./</image>)\n{question}'}]
    rendered = tok.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    return engine, rendered, stop_ids
290
291
292# LLama 3.2
def run_mllama(question: str, modality: str):
    """Create the engine and prompt for Llama-3.2-11B-Vision-Instruct.

    Args:
        question: Text placed after the ``<|image|><|begin_of_text|>`` prefix.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; no stop tokens are set.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    assert modality == "image"

    # Note: The default max_num_seqs (256) and max_model_len (131072) for
    # this model may cause OOM. You may lower either to run this example on
    # lower-end GPUs.
    #
    # The configuration below has been confirmed to launch on a single
    # L40 GPU.
    engine_kwargs = {
        "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
        "max_model_len": 4096,
        "max_num_seqs": 16,
        "enforce_eager": True,
        "disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache,
    }
    engine = LLM(**engine_kwargs)

    return engine, f"<|image|><|begin_of_text|>{question}", None
314
315
316# Molmo
def run_molmo(question, modality):
    """Create the engine and prompt for allenai/Molmo-7B-D-0924.

    Args:
        question: Used verbatim as the prompt.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; no stop tokens are set.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    assert modality == "image"

    engine = LLM(
        model="allenai/Molmo-7B-D-0924",
        dtype="bfloat16",
        trust_remote_code=True,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    # The question is passed through without any template.
    return engine, question, None
332
333
334# NVLM-D
def run_nvlm_d(question: str, modality: str):
    """Create the engine and prompt for nvidia/NVLM-D-72B.

    The prompt is rendered with the tokenizer's own chat template.

    Args:
        question: Text placed after the ``<image>`` placeholder.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; no stop tokens are set.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    assert modality == "image"

    checkpoint = "nvidia/NVLM-D-72B"

    # Adjust these settings as necessary to fit in GPU memory.
    engine = LLM(
        model=checkpoint,
        trust_remote_code=True,
        tensor_parallel_size=4,
        max_model_len=4096,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    chat = [{'role': 'user', 'content': f"<image>\n{question}"}]
    rendered = tok.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    return engine, rendered, None
357
358
359# PaliGemma
def run_paligemma(question: str, modality: str):
    """Create the engine and prompt for google/paligemma-3b-mix-224.

    Args:
        question: Not used; the prompt is the fixed string ``"caption en"``.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; no stop tokens are set.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    assert modality == "image"

    engine = LLM(
        model="google/paligemma-3b-mix-224",
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    # PaliGemma has a special prompt format for VQA.
    return engine, "caption en", None
369
370
371# PaliGemma 2
def run_paligemma2(question: str, modality: str):
    """Create the engine and prompt for google/paligemma2-3b-ft-docci-448.

    Args:
        question: Not used; the prompt is the fixed string ``"caption en"``.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; no stop tokens are set.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    assert modality == "image"

    engine_kwargs = {
        "model": "google/paligemma2-3b-ft-docci-448",
        "disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache,
    }
    # PaliGemma 2 has a special prompt format for VQA.
    return LLM(**engine_kwargs), "caption en", None
381
382
383# Phi-3-Vision
def run_phi3v(question: str, modality: str):
    """Create the engine and prompt for microsoft/Phi-3.5-vision-instruct.

    Args:
        question: Text placed after the ``<|image_1|>`` placeholder.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; no stop tokens are set.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    assert modality == "image"

    # num_crops is an override kwarg to the multimodal image processor.
    # For some models, e.g. Phi-3.5-vision-instruct, it is recommended to
    # use 16 for single-frame scenarios and 4 for multi-frame.
    #
    # Generally speaking, a larger num_crops results in more tokens per
    # image instance, because it may scale the image more during image
    # preprocessing. References in the model docs, and the formula for the
    # number of image tokens after the preprocessing transform:
    #
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
    #
    # Note - mm_processor_kwargs can also be passed to generate/chat calls.
    processor_kwargs = {"num_crops": 16}

    engine = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        mm_processor_kwargs=processor_kwargs,
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )
    text = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"
    return engine, text, None
412
413
414# Pixtral HF-format
def run_pixtral_hf(question: str, modality: str):
    """Create the engine and prompt for mistral-community/pixtral-12b.

    Args:
        question: Text placed before the ``[IMG]`` placeholder.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; no stop tokens are set.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    assert modality == "image"

    engine_kwargs = {
        "model": "mistral-community/pixtral-12b",
        "max_model_len": 8192,
        "disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache,
    }
    engine = LLM(**engine_kwargs)

    return engine, f"<s>[INST]{question}\n[IMG][/INST]", None
429
430
431# Qwen
def run_qwen_vl(question: str, modality: str):
    """Create the engine and prompt for Qwen/Qwen-VL.

    Args:
        question: Text placed before the ``Picture 1: <img></img>`` marker.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; no stop tokens are set.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    assert modality == "image"

    engine_kwargs = {
        "model": "Qwen/Qwen-VL",
        "trust_remote_code": True,
        "max_model_len": 1024,
        "max_num_seqs": 2,
        "disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache,
    }
    engine = LLM(**engine_kwargs)

    return engine, f"{question}Picture 1: <img></img>\n", None
446
447
448# Qwen2-VL
def run_qwen2_vl(question: str, modality: str):
    """Create the engine and prompt for Qwen/Qwen2-VL-7B-Instruct.

    Args:
        question: Text placed after the vision placeholder in the user turn.
        modality: ``"image"`` or ``"video"``; selects the ``<|image_pad|>``
            or ``<|video_pad|>`` placeholder.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple; no stop tokens are set.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.

    Note:
        Reads the module-level ``args`` for
        ``disable_mm_preprocessor_cache``.
    """
    # Resolve the placeholder before loading the model. Without this check an
    # unsupported modality would leave `placeholder` unbound and fail only
    # after the model load.
    placeholders = {
        "image": "<|image_pad|>",
        "video": "<|video_pad|>",
    }
    if modality not in placeholders:
        raise ValueError(f"Modality {modality} is not supported.")
    placeholder = placeholders[modality]

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    llm = LLM(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
        },
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
              f"{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
    stop_token_ids = None
    return llm, prompt, stop_token_ids
476
477
# Maps each value accepted by `--model-type` to the function that builds the
# engine, prompt and stop tokens for that model. Insertion order is the order
# in which the choices are listed by the argument parser.
model_example_map = {
    # Image-only examples
    "aria": run_aria,
    "blip-2": run_blip2,
    "chameleon": run_chameleon,
    "fuyu": run_fuyu,
    "glm4v": run_glm4v,
    "h2ovl_chat": run_h2ovl,
    "idefics3": run_idefics3,
    "internvl_chat": run_internvl,
    "llava": run_llava,
    "llava-next": run_llava_next,
    # Video-only example
    "llava-next-video": run_llava_next_video,
    # Image or video
    "llava-onevision": run_llava_onevision,
    # Image-only examples
    "mantis": run_mantis,
    "minicpmv": run_minicpmv,
    "mllama": run_mllama,
    "molmo": run_molmo,
    "NVLM_D": run_nvlm_d,
    "paligemma": run_paligemma,
    "paligemma2": run_paligemma2,
    "phi3_v": run_phi3v,
    "pixtral_hf": run_pixtral_hf,
    "qwen_vl": run_qwen_vl,
    # Image or video
    "qwen2_vl": run_qwen2_vl,
}
503
504
def get_multi_modal_input(args):
    """Load the bundled sample asset and question for ``args.modality``.

    Args:
        args: Namespace providing ``modality`` and, for video,
            ``num_frames``.

    Returns:
        A dict of the form::

            {
                "data": image or video,
                "question": question,
            }

    Raises:
        ValueError: If ``args.modality`` is neither ``"image"`` nor
            ``"video"``.
    """
    if args.modality == "image":
        # Sample image, converted to RGB.
        asset = ImageAsset("cherry_blossom")
        rgb_image = asset.pil_image.convert("RGB")
        return {
            "data": rgb_image,
            "question": "What is the content of this image?",
        }

    if args.modality == "video":
        # Sample video with the requested number of frames.
        asset = VideoAsset(name="sample_demo_1.mp4",
                           num_frames=args.num_frames)
        return {
            "data": asset.np_ndarrays,
            "question": "Why is this video funny?",
        }

    raise ValueError(f"Modality {args.modality} is not supported.")
536
537
def apply_image_repeat(image_repeat_prob, num_prompts, data, prompt, modality):
    """Build ``num_prompts`` inputs, repeating images with a given probability.

    Used to simulate hits and misses for the multi-modal preprocessor cache.
    For every prompt a weighted coin is flipped. On "repeat" the current
    image is reused. Otherwise the current image is copied and its top-left
    pixel is overwritten with a value derived from the prompt index, so the
    copy differs from the images used so far.

    Args:
        image_repeat_prob: Probability in ``[0, 1]`` of reusing the current
            image for a prompt.
        num_prompts: Number of inputs to build.
        data: Starting image. It must provide ``copy()`` and
            ``putpixel((x, y), value)``.
            NOTE(review): video data is presumably a NumPy array without
            ``putpixel`` and would fail on the first non-repeat - confirm
            against the caller.
        prompt: Prompt text shared by all inputs.
        modality: Key under which the image is stored in
            ``multi_modal_data``.

    Returns:
        A list of ``{"prompt": ..., "multi_modal_data": {modality: image}}``
        dicts, one per prompt.

    Raises:
        ValueError: If ``image_repeat_prob`` is outside ``[0, 1]``.
    """
    # Raise explicitly rather than assert, so the check survives `python -O`.
    if not 0 <= image_repeat_prob <= 1.0:
        raise ValueError(
            f"image_repeat_prob must be in [0, 1], got {image_repeat_prob}.")

    no_yes = [0, 1]
    probs = [1.0 - image_repeat_prob, image_repeat_prob]

    inputs = []
    cur_image = data
    for i in range(num_prompts):
        repeat = random.choices(no_yes, probs)[0]
        if not repeat:
            # No repeat => modify one pixel on a fresh copy. Each channel is
            # masked to 0..255 so the value stays a valid 8-bit RGB triple
            # even once the prompt index reaches 65536.
            cur_image = cur_image.copy()
            new_val = ((i >> 16) & 0xFF, (i >> 8) & 0xFF, i & 0xFF)
            cur_image.putpixel((0, 0), new_val)

        inputs.append({
            "prompt": prompt,
            "multi_modal_data": {
                modality: cur_image
            }
        })

    return inputs
565
566
def main(args):
    """Run the selected vision-language example and print the generations.

    Loads the sample input for ``args.modality``, builds the engine and
    prompt for ``args.model_type``, generates ``args.num_prompts``
    completions and prints the text of the first output of each request.

    Args:
        args: Parsed command-line namespace. Reads ``model_type``,
            ``modality``, ``num_prompts``, ``image_repeat_prob`` and
            ``time_generate`` (``get_multi_modal_input`` additionally reads
            ``num_frames`` for video).

    Raises:
        ValueError: If ``args.num_prompts`` is not positive or
            ``args.model_type`` is not a key of ``model_example_map``.

    Note:
        The per-model ``run_*`` functions read ``disable_mm_preprocessor_cache``
        from the module-level ``args``, not from this parameter.
    """
    import time

    # Validate the cheap inputs first so a bad value is rejected before the
    # model is loaded. An explicit raise also survives `python -O`.
    if args.num_prompts <= 0:
        raise ValueError(
            f"--num-prompts must be positive, got {args.num_prompts}.")

    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    data = mm_input["data"]
    question = mm_input["question"]

    llm, prompt, stop_token_ids = model_example_map[model](question, modality)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)

    if args.num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompt,
            "multi_modal_data": {
                modality: data
            },
        }
    elif args.image_repeat_prob is not None:
        # Batch inference, repeating images with the specified probability
        # of "image_repeat_prob"
        inputs = apply_image_repeat(args.image_repeat_prob, args.num_prompts,
                                    data, prompt, modality)
    else:
        # Batch inference, using the same image for all prompts
        inputs = [{
            "prompt": prompt,
            "multi_modal_data": {
                modality: data
            },
        } for _ in range(args.num_prompts)]

    # Single generate() call; the elapsed time is only reported on request.
    # perf_counter() is used because it is meant for measuring intervals.
    start_time = time.perf_counter()
    outputs = llm.generate(inputs, sampling_params=sampling_params)
    elapsed_time = time.perf_counter() - start_time
    if args.time_generate:
        print(f"-- generate time = {elapsed_time}")

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
624
625
if __name__ == "__main__":
    # Command-line entry point. `args` is deliberately left at module level
    # because the per-model `run_*` functions read it as a global.
    parser = FlexibleArgumentParser(
        description=('Demo on using vLLM for offline inference with '
                     'vision language models for text generation'))

    # Which model to run and how many prompts to send.
    parser.add_argument('--model-type', '-m',
                        type=str,
                        default="llava",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts', type=int, default=4,
                        help='Number of prompts to run.')

    # Input selection.
    parser.add_argument('--modality',
                        type=str,
                        default="image",
                        choices=['image', 'video'],
                        help='Modality of the input.')
    parser.add_argument('--num-frames', type=int, default=16,
                        help='Number of frames to extract from the video.')

    # Multi-modal preprocessor cache controls.
    parser.add_argument('--image-repeat-prob', type=float, default=None,
                        help=('Simulates the hit-ratio for multi-modal '
                              'preprocessor cache (if enabled)'))
    parser.add_argument('--disable-mm-preprocessor-cache',
                        action='store_true',
                        help=('If True, disables caching of multi-modal '
                              'preprocessor/mapper.'))

    # Timing.
    parser.add_argument('--time-generate', action='store_true',
                        help=('If True, then print the total generate() '
                              'call time'))

    args = parser.parse_args()
    main(args)