Offline Inference Vision Language

Offline Inference Vision Language

Source vllm-project/vllm.

  1"""
  2This example shows how to use vLLM for running offline inference with
  3the correct prompt format on vision language models for text generation.
  4
  5For most models, the prompt format should follow corresponding examples
  6on HuggingFace model repository.
  7"""
  8import random
  9
 10from transformers import AutoTokenizer
 11
 12from vllm import LLM, SamplingParams
 13from vllm.assets.image import ImageAsset
 14from vllm.assets.video import VideoAsset
 15from vllm.utils import FlexibleArgumentParser
 16
 17# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
 18# lower-end GPUs.
 19# Unless specified, these settings have been tested to work on a single L4.
 20
 21
 22# Aria
 23def run_aria(question: str, modality: str):
 24    assert modality == "image"
 25    model_name = "rhymes-ai/Aria"
 26
 27    llm = LLM(model=model_name,
 28              tokenizer_mode="slow",
 29              trust_remote_code=True,
 30              dtype="bfloat16",
 31              mm_cache_preprocessor=args.mm_cache_preprocessor)
 32
 33    prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
 34              "<|im_end|>\n<|im_start|>assistant\n")
 35
 36    stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
 37    return llm, prompt, stop_token_ids
 38
 39
 40# BLIP-2
 41def run_blip2(question: str, modality: str):
 42    assert modality == "image"
 43
 44    # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
 45    # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
 46    prompt = f"Question: {question} Answer:"
 47    llm = LLM(model="Salesforce/blip2-opt-2.7b",
 48              mm_cache_preprocessor=args.mm_cache_preprocessor)
 49    stop_token_ids = None
 50    return llm, prompt, stop_token_ids
 51
 52
 53# Chameleon
 54def run_chameleon(question: str, modality: str):
 55    assert modality == "image"
 56
 57    prompt = f"{question}<image>"
 58    llm = LLM(model="facebook/chameleon-7b",
 59              max_model_len=4096,
 60              mm_cache_preprocessor=args.mm_cache_preprocessor)
 61    stop_token_ids = None
 62    return llm, prompt, stop_token_ids
 63
 64
 65# Fuyu
 66def run_fuyu(question: str, modality: str):
 67    assert modality == "image"
 68
 69    prompt = f"{question}\n"
 70    llm = LLM(model="adept/fuyu-8b",
 71              max_model_len=2048,
 72              max_num_seqs=2,
 73              mm_cache_preprocessor=args.mm_cache_preprocessor)
 74    stop_token_ids = None
 75    return llm, prompt, stop_token_ids
 76
 77
 78# GLM-4v
 79def run_glm4v(question: str, modality: str):
 80    assert modality == "image"
 81    model_name = "THUDM/glm-4v-9b"
 82
 83    llm = LLM(model=model_name,
 84              max_model_len=2048,
 85              max_num_seqs=2,
 86              trust_remote_code=True,
 87              enforce_eager=True,
 88              mm_cache_preprocessor=args.mm_cache_preprocessor)
 89    prompt = question
 90    stop_token_ids = [151329, 151336, 151338]
 91    return llm, prompt, stop_token_ids
 92
 93
 94# H2OVL-Mississippi
 95def run_h2ovl(question: str, modality: str):
 96    assert modality == "image"
 97
 98    model_name = "h2oai/h2ovl-mississippi-2b"
 99
100    llm = LLM(
101        model=model_name,
102        trust_remote_code=True,
103        max_model_len=8192,
104        mm_cache_preprocessor=args.mm_cache_preprocessor,
105    )
106
107    tokenizer = AutoTokenizer.from_pretrained(model_name,
108                                              trust_remote_code=True)
109    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
110    prompt = tokenizer.apply_chat_template(messages,
111                                           tokenize=False,
112                                           add_generation_prompt=True)
113
114    # Stop tokens for H2OVL-Mississippi
115    # https://huggingface.co/h2oai/h2ovl-mississippi-2b
116    stop_token_ids = [tokenizer.eos_token_id]
117    return llm, prompt, stop_token_ids
118
119
120# Idefics3-8B-Llama3
def run_idefics3(question: str, modality: str):
    """Build the engine and prompt for HuggingFaceM4/Idefics3-8B-Llama3.

    Only image input is handled. The image processor is given a
    ``longest_edge`` of ``3 * 364`` through ``mm_processor_kwargs``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple with ``stop_token_ids``
        set to ``None``.
    """
    assert modality == "image"

    # If you are running out of memory, you can reduce the "longest_edge".
    # see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
    image_size = {"longest_edge": 3 * 364}

    engine = LLM(
        model="HuggingFaceM4/Idefics3-8B-Llama3",
        max_model_len=8192,
        max_num_seqs=2,
        enforce_eager=True,
        mm_processor_kwargs={"size": image_size},
        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )
    prompt = ("<|begin_of_text|>User:<image>"
              f"{question}<end_of_utterance>\nAssistant:")
    return engine, prompt, None
144
145
146# InternVL
def run_internvl(question: str, modality: str):
    """Build the engine, prompt and stop tokens for OpenGVLab/InternVL2-2B.

    Only image input is handled. The prompt is rendered through the
    checkpoint tokenizer's chat template; stop token ids are looked up
    from a fixed list of stop words.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple.
    """
    assert modality == "image"

    checkpoint = "OpenGVLab/InternVL2-2B"
    engine = LLM(model=checkpoint,
                 trust_remote_code=True,
                 max_model_len=4096,
                 mm_cache_preprocessor=args.mm_cache_preprocessor)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint,
                                              trust_remote_code=True)
    user_turn = {'role': 'user', 'content': f"<image>\n{question}"}
    prompt = tokenizer.apply_chat_template([user_turn],
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL. Model variants may use different stop
    # tokens; refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
    stop_words = ("<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>")
    stop_token_ids = []
    for word in stop_words:
        stop_token_ids.append(tokenizer.convert_tokens_to_ids(word))
    return engine, prompt, stop_token_ids
173
174
175# LLaVA-1.5
def run_llava(question: str, modality: str):
    """Build the engine and prompt for llava-hf/llava-1.5-7b-hf.

    Only image input is handled.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple with ``stop_token_ids``
        set to ``None``.
    """
    assert modality == "image"

    prompt = "USER: <image>\n" + f"{question}\nASSISTANT:"
    engine = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        max_model_len=4096,
        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )
    return engine, prompt, None
186
187
188# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(question: str, modality: str):
    """Build the engine and prompt for llava-hf/llava-v1.6-mistral-7b-hf.

    Only image input is handled.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple with ``stop_token_ids``
        set to ``None``.
    """
    assert modality == "image"

    prompt = "[INST] <image>\n" + f"{question} [/INST]"
    engine = LLM(
        model="llava-hf/llava-v1.6-mistral-7b-hf",
        max_model_len=8192,
        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )
    return engine, prompt, None
198
199
200# LlaVA-NeXT-Video
201# Currently only support for video input
def run_llava_next_video(question: str, modality: str):
    """Build the engine and prompt for llava-hf/LLaVA-NeXT-Video-7B-hf.

    Only video input is handled; any other modality trips the assertion.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple with ``stop_token_ids``
        set to ``None``.
    """
    assert modality == "video"

    prompt = "USER: <video>\n" + f"{question} ASSISTANT:"
    engine = LLM(
        model="llava-hf/LLaVA-NeXT-Video-7B-hf",
        max_model_len=8192,
        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )
    return engine, prompt, None
211
212
213# LLaVA-OneVision
def run_llava_onevision(question: str, modality: str):
    """Build the engine and prompt for llava-onevision-qwen2-7b-ov-hf.

    Both image and video input are handled; the placeholder tag in the
    prompt (``<image>`` or ``<video>``) follows the modality. The
    preprocessor-cache flag is read from the module-level ``args``.

    Args:
        question: Text question inserted into the user turn.
        modality: Either ``"image"`` or ``"video"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple with ``stop_token_ids``
        set to ``None``.

    Raises:
        ValueError: If ``modality`` is neither ``"image"`` nor ``"video"``.
    """
    # Reject unknown modalities before loading the model; otherwise `prompt`
    # would never be bound and the failure would only show up after the
    # engine has been constructed.
    if modality not in ("image", "video"):
        raise ValueError(f"Modality {modality} is not supported.")

    # The run of spaces between <|im_end|> and <|im_start|> is reproduced
    # verbatim from the previous prompt text (one space plus eight spaces of
    # continuation-line indentation) so the generated prompt is unchanged.
    prompt = (f"<|im_start|>user <{modality}>\n{question}<|im_end|> "
              "        <|im_start|>assistant\n")

    llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
              max_model_len=16384,
              mm_cache_preprocessor=args.mm_cache_preprocessor)
    stop_token_ids = None
    return llm, prompt, stop_token_ids
229
230
231# Mantis
def run_mantis(question: str, modality: str):
    """Build the engine, prompt and stop tokens for Mantis-8B-siglip-llama3.

    Only image input is handled. The user turn is the question, a newline
    and the ``<image>`` placeholder, wrapped in Llama-3 style header tokens.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple.
    """
    assert modality == "image"

    user_content = f"{question}\n<image>"
    prompt = (
        '<|start_header_id|>user<|end_header_id|>\n\n'
        f'{user_content}'
        '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n')

    engine = LLM(
        model="TIGER-Lab/Mantis-8B-siglip-llama3",
        max_model_len=4096,
        hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )
    return engine, prompt, [128009]
246
247
248# MiniCPM-V
def run_minicpmv(question: str, modality: str):
    """Build the engine, prompt and stop tokens for MiniCPM-V 2.6.

    Only image input is handled. The tokenizer is loaded first and is used
    both for the stop-token lookup and for rendering the chat template.
    Settings for the 2.0 and 2.5 checkpoints are kept as comments.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple.
    """
    assert modality == "image"

    # 2.0
    # The official repo doesn't work yet, so we need to use a fork for now.
    # For more details, see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 # noqa
    # checkpoint = "HwwwH/MiniCPM-V-2"

    # 2.5
    # checkpoint = "openbmb/MiniCPM-Llama3-V-2_5"

    # 2.6
    checkpoint = "openbmb/MiniCPM-V-2_6"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint,
                                              trust_remote_code=True)
    engine = LLM(model=checkpoint,
                 max_model_len=4096,
                 max_num_seqs=2,
                 trust_remote_code=True,
                 mm_cache_preprocessor=args.mm_cache_preprocessor)

    # NOTE: The stop_token_ids are different for various versions of MiniCPM-V
    # 2.0
    # stop_token_ids = [tokenizer.eos_id]

    # 2.5
    # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]

    # 2.6
    stop_token_ids = []
    for word in ('<|im_end|>', '<|endoftext|>'):
        stop_token_ids.append(tokenizer.convert_tokens_to_ids(word))

    user_turn = {'role': 'user', 'content': f'(<image>./</image>)\n{question}'}
    prompt = tokenizer.apply_chat_template([user_turn],
                                           tokenize=False,
                                           add_generation_prompt=True)
    return engine, prompt, stop_token_ids
290
291
292# LLama 3.2
def run_mllama(question: str, modality: str):
    """Build the engine and prompt for Llama-3.2-11B-Vision-Instruct.

    Only image input is handled.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple with ``stop_token_ids``
        set to ``None``.
    """
    assert modality == "image"

    # Note: The default setting of max_num_seqs (256) and
    # max_model_len (131072) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.
    #
    # The configuration below has been confirmed to launch on a single
    # L40 GPU.
    engine_kwargs = dict(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        max_model_len=4096,
        max_num_seqs=16,
        enforce_eager=True,
        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )
    engine = LLM(**engine_kwargs)

    return engine, f"<|image|><|begin_of_text|>{question}", None
314
315
316# Molmo
def run_molmo(question: str, modality: str):
    """Build the engine and prompt for allenai/Molmo-7B-D-0924.

    Only image input is handled; any other modality trips the assertion.
    The question is used as the prompt unchanged. The preprocessor-cache
    flag is read from the module-level ``args``.

    Args:
        question: Text question, passed through as the prompt.
        modality: Must be ``"image"``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple with ``stop_token_ids``
        set to ``None``.
    """
    assert modality == "image"

    model_name = "allenai/Molmo-7B-D-0924"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        dtype="bfloat16",
        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )

    prompt = question
    stop_token_ids = None
    return llm, prompt, stop_token_ids
332
333
334# NVLM-D
def run_nvlm_d(question: str, modality: str):
    """Build the engine and prompt for nvidia/NVLM-D-72B.

    Only image input is handled. The engine is created with
    ``tensor_parallel_size=4`` and the prompt is rendered through the
    checkpoint tokenizer's chat template.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple with ``stop_token_ids``
        set to ``None``.
    """
    assert modality == "image"

    checkpoint = "nvidia/NVLM-D-72B"

    # Adjust these settings as necessary to fit in GPU memory.
    engine = LLM(model=checkpoint,
                 trust_remote_code=True,
                 max_model_len=4096,
                 tensor_parallel_size=4,
                 mm_cache_preprocessor=args.mm_cache_preprocessor)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint,
                                              trust_remote_code=True)
    user_turn = {'role': 'user', 'content': f"<image>\n{question}"}
    prompt = tokenizer.apply_chat_template([user_turn],
                                           tokenize=False,
                                           add_generation_prompt=True)
    return engine, prompt, None
357
358
359# PaliGemma
def run_paligemma(question: str, modality: str):
    """Build the engine and prompt for google/paligemma-3b-mix-224.

    Only image input is handled. The prompt is the fixed string
    ``"caption en"``; ``question`` is not used.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple with ``stop_token_ids``
        set to ``None``.
    """
    assert modality == "image"

    # PaliGemma has special prompt format for VQA
    fixed_prompt = "caption en"
    engine = LLM(
        model="google/paligemma-3b-mix-224",
        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )
    return engine, fixed_prompt, None
369
370
371# PaliGemma 2
def run_paligemma2(question: str, modality: str):
    """Build the engine and prompt for google/paligemma2-3b-ft-docci-448.

    Only image input is handled. The prompt is the fixed string
    ``"caption en"``; ``question`` is not used.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple with ``stop_token_ids``
        set to ``None``.
    """
    assert modality == "image"

    # PaliGemma 2 has special prompt format for VQA
    fixed_prompt = "caption en"
    engine = LLM(
        model="google/paligemma2-3b-ft-docci-448",
        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )
    return engine, fixed_prompt, None
381
382
383# Phi-3-Vision
def run_phi3v(question: str, modality: str):
    """Build the engine and prompt for microsoft/Phi-3.5-vision-instruct.

    Only image input is handled. ``num_crops=16`` is forwarded to the
    multimodal image processor through ``mm_processor_kwargs``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple with ``stop_token_ids``
        set to ``None``.
    """
    assert modality == "image"

    prompt = ("<|user|>\n<|image_1|>\n"
              f"{question}<|end|>\n<|assistant|>\n")

    # num_crops is an override kwarg to the multimodal image processor.
    # For some models, e.g. Phi-3.5-vision-instruct, it is recommended to
    # use 16 for single-frame scenarios and 4 for multi-frame.
    #
    # Generally speaking, a larger num_crops results in more tokens per
    # image instance, because it may scale the image more during image
    # preprocessing. Model-doc references and the formula for the number of
    # image tokens after the preprocessing transform can be found below.
    #
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
    processor_overrides = {"num_crops": 16}

    engine = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs=processor_overrides,
        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )
    return engine, prompt, None
412
413
414# Pixtral HF-format
def run_pixtral_hf(question: str, modality: str):
    """Build the engine and prompt for mistral-community/pixtral-12b.

    Only image input is handled. The ``[IMG]`` placeholder follows the
    question inside the ``[INST]`` block.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple with ``stop_token_ids``
        set to ``None``.
    """
    assert modality == "image"

    engine = LLM(model="mistral-community/pixtral-12b",
                 max_model_len=8192,
                 mm_cache_preprocessor=args.mm_cache_preprocessor)

    return engine, f"<s>[INST]{question}\n[IMG][/INST]", None
429
430
431# Qwen
def run_qwen_vl(question: str, modality: str):
    """Build the engine and prompt for Qwen/Qwen-VL.

    Only image input is handled. The question is followed directly by the
    ``Picture 1: <img></img>`` placeholder line.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple with ``stop_token_ids``
        set to ``None``.
    """
    assert modality == "image"

    engine_kwargs = dict(
        model="Qwen/Qwen-VL",
        trust_remote_code=True,
        max_model_len=1024,
        max_num_seqs=2,
        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )
    engine = LLM(**engine_kwargs)

    return engine, f"{question}Picture 1: <img></img>\n", None
446
447
448# Qwen2-VL
def run_qwen2_vl(question: str, modality: str):
    """Build the engine and prompt for Qwen/Qwen2-VL-7B-Instruct.

    Only image input is handled. Pixel bounds for the image processor are
    forwarded through ``mm_processor_kwargs``.

    Returns:
        A ``(llm, prompt, stop_token_ids)`` tuple with ``stop_token_ids``
        set to ``None``.
    """
    assert modality == "image"

    pixel_bounds = {
        "min_pixels": 28 * 28,
        "max_pixels": 1280 * 28 * 28,
    }
    engine = LLM(
        model="Qwen/Qwen2-VL-7B-Instruct",
        max_model_len=4096,
        max_num_seqs=5,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs=pixel_bounds,
        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )

    system_turn = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    user_turn = ("<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
                 f"{question}<|im_end|>\n")
    prompt = system_turn + user_turn + "<|im_start|>assistant\n"
    return engine, prompt, None
472
473
# Registry from model-type key to the function that prepares that model.
# Every function takes (question, modality) and returns a
# (llm, prompt, stop_token_ids) tuple. The keys are the values accepted by
# the --model-type command-line option.
model_example_map = {
    "aria": run_aria,
    "blip-2": run_blip2,
    "chameleon": run_chameleon,
    "fuyu": run_fuyu,
    "glm4v": run_glm4v,
    "h2ovl_chat": run_h2ovl,
    "idefics3": run_idefics3,
    "internvl_chat": run_internvl,
    "llava": run_llava,
    "llava-next": run_llava_next,
    "llava-next-video": run_llava_next_video,
    "llava-onevision": run_llava_onevision,
    "mantis": run_mantis,
    "minicpmv": run_minicpmv,
    "mllama": run_mllama,
    "molmo": run_molmo,
    "NVLM_D": run_nvlm_d,
    "paligemma": run_paligemma,
    "paligemma2": run_paligemma2,
    "phi3_v": run_phi3v,
    "pixtral_hf": run_pixtral_hf,
    "qwen_vl": run_qwen_vl,
    "qwen2_vl": run_qwen2_vl,
}
499
500
def get_multi_modal_input(args):
    """Load the demo asset and question that match ``args.modality``.

    Args:
        args: Namespace providing ``modality`` and, for video input,
            ``num_frames``.

    Returns:
        A dict of the form::

            {
                "data": image or video,
                "question": question,
            }

    Raises:
        ValueError: If ``args.modality`` is neither ``"image"`` nor
            ``"video"``.
    """
    modality = args.modality

    if modality == "image":
        # Input image and question
        asset = ImageAsset("cherry_blossom")
        payload = asset.pil_image.convert("RGB")
        question = "What is the content of this image?"
    elif modality == "video":
        # Input video and question
        asset = VideoAsset(name="sample_demo_1.mp4",
                           num_frames=args.num_frames)
        payload = asset.np_ndarrays
        question = "Why is this video funny?"
    else:
        raise ValueError(f"Modality {modality} is not supported.")

    return {"data": payload, "question": question}
532
533
def apply_image_repeat(image_repeat_prob, num_prompts, data, prompt, modality):
    """Build ``num_prompts`` requests, repeating images with a probability.

    Used to simulate hit/miss for the MM preprocessor cache. For each
    prompt, with probability ``image_repeat_prob`` the image used by the
    previous prompt is reused as-is; otherwise a copy is made and its
    top-left pixel is overwritten with a value derived from the prompt
    index, so the copy differs from the image it was copied from.

    Args:
        image_repeat_prob: Probability in ``[0, 1]`` of reusing the current
            image.
        num_prompts: Number of request dicts to produce.
        data: Starting image. It must provide ``copy()`` and
            ``putpixel((x, y), value)``; it is never modified in place.
        prompt: Prompt text shared by every request.
        modality: Key under which the image is stored in
            ``multi_modal_data``.

    Returns:
        A list of ``{"prompt": ..., "multi_modal_data": {modality: image}}``
        dicts, one per prompt.
    """
    assert 0 <= image_repeat_prob <= 1.0
    no_yes = [0, 1]
    probs = [1.0 - image_repeat_prob, image_repeat_prob]

    inputs = []
    cur_image = data
    for i in range(num_prompts):
        repeat = random.choices(no_yes, probs)[0]
        if repeat == 0:
            # No repeat => modify one pixel on a fresh copy.
            cur_image = cur_image.copy()
            # Spread the prompt index over three channels, masking each one
            # to a single byte so no channel value can exceed 255.
            new_val = ((i >> 16) & 0xFF, (i >> 8) & 0xFF, i & 0xFF)
            cur_image.putpixel((0, 0), new_val)

        inputs.append({
            "prompt": prompt,
            "multi_modal_data": {
                modality: cur_image
            }
        })

    return inputs
561
562
def main(args):
    """Run offline generation for the selected model and print the outputs.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``model_type``, ``modality``, ``num_prompts``,
            ``image_repeat_prob`` and ``time_generate``;
            ``get_multi_modal_input`` additionally reads ``num_frames`` for
            video input.

    Raises:
        ValueError: If ``args.model_type`` is not a key of
            ``model_example_map``, or if ``args.num_prompts`` is not
            positive.
    """
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    # Validate the request count before the (expensive) model load so that
    # bad input fails immediately, and keep the check active under -O.
    if args.num_prompts <= 0:
        raise ValueError(
            f"--num-prompts must be positive, got {args.num_prompts}.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    data = mm_input["data"]
    question = mm_input["question"]

    llm, prompt, stop_token_ids = model_example_map[model](question, modality)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)

    if args.num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompt,
            "multi_modal_data": {
                modality: data
            },
        }
    elif args.image_repeat_prob is not None:
        # Batch inference, repeating images with the specified probability
        # of "image_repeat_prob"
        inputs = apply_image_repeat(args.image_repeat_prob, args.num_prompts,
                                    data, prompt, modality)
    else:
        # Batch inference using the same image for all prompts
        inputs = [{
            "prompt": prompt,
            "multi_modal_data": {
                modality: data
            },
        } for _ in range(args.num_prompts)]

    if args.time_generate:
        import time

        # perf_counter is monotonic, so the measured interval is not
        # affected by system clock adjustments.
        start_time = time.perf_counter()

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    if args.time_generate:
        elapsed_time = time.perf_counter() - start_time
        print(f"-- generate time = {elapsed_time}")

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
620
621
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, then run.
    # The parsed namespace is kept in the module-level name `args` because
    # the run_* functions read `args.mm_cache_preprocessor` from it.
    parser = FlexibleArgumentParser(
        description=('Demo on using vLLM for offline inference with '
                     'vision language models for text generation'))

    parser.add_argument(
        '--model-type',
        '-m',
        type=str,
        default="llava",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        '--num-prompts',
        type=int,
        default=4,
        help='Number of prompts to run.',
    )
    parser.add_argument(
        '--modality',
        type=str,
        default="image",
        choices=['image', 'video'],
        help='Modality of the input.',
    )
    parser.add_argument(
        '--num-frames',
        type=int,
        default=16,
        help='Number of frames to extract from the video.',
    )
    parser.add_argument(
        '--image-repeat-prob',
        type=float,
        default=None,
        help=('Simulates the hit-ratio for multi-modal preprocessor cache'
              ' (if enabled)'),
    )
    parser.add_argument(
        '--mm-cache-preprocessor',
        action='store_true',
        help='If True, enable caching of multi-modal preprocessor/mapper.',
    )
    parser.add_argument(
        '--time-generate',
        action='store_true',
        help='If True, then print the total generate() call time',
    )

    args = parser.parse_args()
    main(args)