Offline Inference Vision Language

Source: vllm-project/vllm.
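The script selects a model-specific prompt builder via --model-type, loads a bundled image or video asset, and runs single-prompt or batched generation. Assuming it is saved under its usual name in the examples directory, a typical invocation is: python offline_inference_vision_language.py --model-type llava --modality image --num-prompts 4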

  1"""
  2This example shows how to use vLLM for running offline inference 
  3with the correct prompt format on vision language models.
  4
  5For most models, the prompt format should follow corresponding examples
  6on HuggingFace model repository.
  7"""
  8from transformers import AutoTokenizer
  9
 10from vllm import LLM, SamplingParams
 11from vllm.assets.image import ImageAsset
 12from vllm.assets.video import VideoAsset
 13from vllm.utils import FlexibleArgumentParser
 14
 15
# LLaVA-1.5
def run_llava(question, modality):
    assert modality == "image"

    prompt = f"USER: <image>\n{question}\nASSISTANT:"

    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(question, modality):
    assert modality == "image"

    prompt = f"[INST] <image>\n{question} [/INST]"
    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLaVA-NeXT-Video
# Currently only video input is supported
def run_llava_next_video(question, modality):
    assert modality == "video"

    prompt = f"USER: <video>\n{question} ASSISTANT:"
    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLaVA-OneVision
def run_llava_onevision(question, modality):

    if modality == "video":
        prompt = (f"<|im_start|>user <video>\n{question}<|im_end|> "
                  "<|im_start|>assistant\n")

    elif modality == "image":
        prompt = (f"<|im_start|>user <image>\n{question}<|im_end|> "
                  "<|im_start|>assistant\n")

    llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
              max_model_len=32768)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Fuyu
def run_fuyu(question, modality):
    assert modality == "image"

    prompt = f"{question}\n"
    llm = LLM(model="adept/fuyu-8b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Phi-3-Vision
def run_phi3v(question, modality):
    assert modality == "image"

    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"  # noqa: E501
    # Note: The default setting of max_num_seqs (256) and
    # max_model_len (128k) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.

    # In this example, we override max_num_seqs to 5 while
    # keeping the original context length of 128k.

    # num_crops is an override kwarg to the multimodal image processor;
    # for some models, e.g. Phi-3.5-vision-instruct, it is recommended
    # to use 16 for single-frame scenarios and 4 for multi-frame.
    #
    # Generally speaking, a larger value for num_crops results in more
    # tokens per image instance, because it may scale the image up more
    # during preprocessing. References in the model docs, and the formula
    # for the number of image tokens after the preprocessing transform,
    # can be found below.
    #
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
    llm = LLM(
        model="microsoft/Phi-3-vision-128k-instruct",
        trust_remote_code=True,
        max_num_seqs=5,
        mm_processor_kwargs={"num_crops": 16},
    )
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# PaliGemma
def run_paligemma(question, modality):
    assert modality == "image"

    # PaliGemma has a special prompt format for VQA
    prompt = "caption en"
    llm = LLM(model="google/paligemma-3b-mix-224")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Chameleon
def run_chameleon(question, modality):
    assert modality == "image"

    prompt = f"{question}<image>"
    llm = LLM(model="facebook/chameleon-7b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# MiniCPM-V
def run_minicpmv(question, modality):
    assert modality == "image"

    # 2.0
    # The official repo doesn't work yet, so we need to use a fork for now.
    # For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 # noqa
    # model_name = "HwwwH/MiniCPM-V-2"

    # 2.5
    # model_name = "openbmb/MiniCPM-Llama3-V-2_5"

    # 2.6
    model_name = "openbmb/MiniCPM-V-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
    )
    # NOTE: The stop_token_ids differ across MiniCPM-V versions.
    # 2.0
    # stop_token_ids = [tokenizer.eos_id]

    # 2.5
    # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]

    # 2.6
    stop_tokens = ['<|im_end|>', '<|endoftext|>']
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    messages = [{
        'role': 'user',
        'content': f'(<image>./</image>)\n{question}'
    }]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    return llm, prompt, stop_token_ids


# InternVL
def run_internvl(question, modality):
    assert modality == "image"

    model_name = "OpenGVLab/InternVL2-2B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL:
    # model variants may have different stop tokens;
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return llm, prompt, stop_token_ids


# BLIP-2
def run_blip2(question, modality):
    assert modality == "image"

    # The BLIP-2 prompt format listed on the HuggingFace model repository
    # is inaccurate.
    # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
    prompt = f"Question: {question} Answer:"
    llm = LLM(model="Salesforce/blip2-opt-2.7b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Qwen-VL
def run_qwen_vl(question, modality):
    assert modality == "image"

    llm = LLM(
        model="Qwen/Qwen-VL",
        trust_remote_code=True,
        max_num_seqs=5,
    )

    prompt = f"{question}Picture 1: <img></img>\n"
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Qwen2-VL
def run_qwen2_vl(question, modality):
    assert modality == "image"

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    llm = LLM(
        model=model_name,
        max_num_seqs=5,
    )

    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
              f"{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Llama 3.2 Vision (mllama)
def run_mllama(question, modality):
    assert modality == "image"

    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

    # Note: The default setting of max_num_seqs (256) and
    # max_model_len (131072) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.

    # The configuration below has been confirmed to launch on a
    # single H100 GPU.
    llm = LLM(
        model=model_name,
        max_num_seqs=16,
        enforce_eager=True,
    )

    prompt = f"<|image|><|begin_of_text|>{question}"
    stop_token_ids = None
    return llm, prompt, stop_token_ids


model_example_map = {
    "llava": run_llava,
    "llava-next": run_llava_next,
    "llava-next-video": run_llava_next_video,
    "llava-onevision": run_llava_onevision,
    "fuyu": run_fuyu,
    "phi3_v": run_phi3v,
    "paligemma": run_paligemma,
    "chameleon": run_chameleon,
    "minicpmv": run_minicpmv,
    "blip-2": run_blip2,
    "internvl_chat": run_internvl,
    "qwen_vl": run_qwen_vl,
    "qwen2_vl": run_qwen2_vl,
    "mllama": run_mllama,
}
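
# NOTE: To run a model that is not listed above, add a run_<model>() helper
# that returns (llm, prompt, stop_token_ids) and register it in
# model_example_map. A hypothetical sketch (the model name and prompt format
# below are placeholders, not a real checkpoint):
#
#     def run_my_vlm(question, modality):
#         assert modality == "image"
#         prompt = f"USER: <image>\n{question}\nASSISTANT:"  # model-specific
#         return LLM(model="my-org/my-vlm"), prompt, None
#
#     model_example_map["my_vlm"] = run_my_vlm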


def get_multi_modal_input(args):
    """
    return {
        "data": image or video,
        "question": question,
    }
    """
    if args.modality == "image":
        # Input image and question
        image = ImageAsset("cherry_blossom") \
            .pil_image.convert("RGB")
        img_question = "What is the content of this image?"

        return {
            "data": image,
            "question": img_question,
        }

    if args.modality == "video":
        # Input video and question
        video = VideoAsset(name="sample_demo_1.mp4",
                           num_frames=args.num_frames).np_ndarrays
        vid_question = "Why is this video funny?"

        return {
            "data": video,
            "question": vid_question,
        }

    msg = f"Modality {args.modality} is not supported."
    raise ValueError(msg)


def main(args):
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    data = mm_input["data"]
    question = mm_input["question"]

    llm, prompt, stop_token_ids = model_example_map[model](question, modality)

    # We set temperature to 0.2 so that outputs can differ even when
    # all prompts are identical during batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)

    assert args.num_prompts > 0
    if args.num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompt,
            "multi_modal_data": {
                modality: data
            },
        }

    else:
        # Batch inference
        inputs = [{
            "prompt": prompt,
            "multi_modal_data": {
                modality: data
            },
        } for _ in range(args.num_prompts)]

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="llava",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=4,
                        help='Number of prompts to run.')
    parser.add_argument('--modality',
                        type=str,
                        default="image",
                        choices=['image', 'video'],
                        help='Modality of the input.')
    parser.add_argument('--num-frames',
                        type=int,
                        default=16,
                        help='Number of frames to extract from the video.')
    args = parser.parse_args()
    main(args)
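
Every helper above reduces to the same pattern: pass a prompt string together with a multi_modal_data dictionary to LLM.generate. The following minimal sketch condenses the LLaVA-1.5 branch of the example into a standalone snippet (assuming the llava-hf/llava-1.5-7b-hf weights are available for download):

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

# Load the model and the bundled sample image used by this example.
llm = LLM(model="llava-hf/llava-1.5-7b-hf")
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")

# LLaVA-1.5 prompt format; the image is referenced by the <image> token.
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=SamplingParams(temperature=0.2, max_tokens=64),
)
print(outputs[0].outputs[0].text)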