Offline Inference Vision Language

Source: vllm-project/vllm.

  1"""
  2This example shows how to use vLLM for running offline inference 
  3with the correct prompt format on vision language models.
  4
  5For most models, the prompt format should follow the corresponding examples
  6on the HuggingFace model repository.
  7"""
  8from transformers import AutoTokenizer
  9
 10from vllm import LLM, SamplingParams
 11from vllm.assets.image import ImageAsset
 12from vllm.assets.video import VideoAsset
 13from vllm.utils import FlexibleArgumentParser
 14
 15
 16# LLaVA-1.5
 17def run_llava(question):
 18
 19    prompt = f"USER: <image>\n{question}\nASSISTANT:"
 20
 21    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
 22    stop_token_ids = None
 23    return llm, prompt, stop_token_ids
 24
 25
 26# LLaVA-1.6/LLaVA-NeXT
 27def run_llava_next(question):
 28
 29    prompt = f"[INST] <image>\n{question} [/INST]"
 30    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
 31    stop_token_ids = None
 32    return llm, prompt, stop_token_ids
 33
 34
 35# LlaVA-NeXT-Video
 36# Currently only support for video input
 37def run_llava_next_video(question):
 38    prompt = f"USER: <video>\n{question} ASSISTANT:"
 39    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192)
 40    stop_token_ids = None
 41    return llm, prompt, stop_token_ids
 42
 43
 44# Fuyu
 45def run_fuyu(question):
 46
 47    prompt = f"{question}\n"
 48    llm = LLM(model="adept/fuyu-8b")
 49    stop_token_ids = None
 50    return llm, prompt, stop_token_ids
 51
 52
 53# Phi-3-Vision
 54def run_phi3v(question):
 55
 56    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"  # noqa: E501
 57    # Note: The default setting of max_num_seqs (256) and
 58    # max_model_len (128k) for this model may cause OOM.
 59    # You may lower either to run this example on lower-end GPUs.
 60
 61    # In this example, we override max_num_seqs to 5 while
 62    # keeping the original context length of 128k.
 63    llm = LLM(
 64        model="microsoft/Phi-3-vision-128k-instruct",
 65        trust_remote_code=True,
 66        max_num_seqs=5,
 67    )
 68    stop_token_ids = None
 69    return llm, prompt, stop_token_ids
 70
 71
 72# PaliGemma
 73def run_paligemma(question):
 74
 75    # PaliGemma has special prompt format for VQA
 76    prompt = "caption en"
 77    llm = LLM(model="google/paligemma-3b-mix-224")
 78    stop_token_ids = None
 79    return llm, prompt, stop_token_ids
 80
 81
 82# Chameleon
 83def run_chameleon(question):
 84
 85    prompt = f"{question}<image>"
 86    llm = LLM(model="facebook/chameleon-7b")
 87    stop_token_ids = None
 88    return llm, prompt, stop_token_ids
 89
 90
 91# MiniCPM-V
 92def run_minicpmv(question):
 93
 94    # 2.0
 95    # The official repo doesn't work yet, so we need to use a fork for now
 96    # For more details, please see: See: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 # noqa
 97    # model_name = "HwwwH/MiniCPM-V-2"
 98
 99    # 2.5
100    # model_name = "openbmb/MiniCPM-Llama3-V-2_5"
101
102    #2.6
103    model_name = "openbmb/MiniCPM-V-2_6"
104    tokenizer = AutoTokenizer.from_pretrained(model_name,
105                                              trust_remote_code=True)
106    llm = LLM(
107        model=model_name,
108        trust_remote_code=True,
109    )
110    # NOTE The stop_token_ids are different for various versions of MiniCPM-V
111    # 2.0
112    # stop_token_ids = [tokenizer.eos_id]
113
114    # 2.5
115    # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
116
117    # 2.6
118    stop_tokens = ['<|im_end|>', '<|endoftext|>']
119    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
120
121    messages = [{
122        'role': 'user',
123        'content': f'(<image>./</image>)\n{question}'
124    }]
125    prompt = tokenizer.apply_chat_template(messages,
126                                           tokenize=False,
127                                           add_generation_prompt=True)
128    return llm, prompt, stop_token_ids
129
130
# InternVL
def run_internvl(question):
    """Return ``(llm, prompt, stop_token_ids)`` for OpenGVLab/InternVL2-2B.

    The prompt is rendered with the model's chat template, and the stop
    token ids are looked up from the tokenizer.
    """
    model_name = "OpenGVLab/InternVL2-2B"
    llm = LLM(model=model_name, trust_remote_code=True, max_num_seqs=5)

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    chat = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(chat,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop words differ between InternVL variants; check the model card for
    # the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_words = ("<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>")
    stop_token_ids = list(map(tokenizer.convert_tokens_to_ids, stop_words))
    return llm, prompt, stop_token_ids
155
156
# BLIP-2
def run_blip2(question):
    """Return ``(llm, prompt, stop_token_ids)`` for Salesforce/blip2-opt-2.7b."""
    # The prompt format shown on the HuggingFace model repository is
    # inaccurate; see
    # https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 # noqa
    prompt = "Question: {} Answer:".format(question)
    return LLM(model="Salesforce/blip2-opt-2.7b"), prompt, None
166
167
# Qwen-VL
def run_qwen_vl(question):
    """Return ``(llm, prompt, stop_token_ids)`` for Qwen/Qwen-VL.

    The question comes first, followed by a ``Picture 1: <img></img>``
    placeholder line.
    """
    llm = LLM(model="Qwen/Qwen-VL", trust_remote_code=True, max_num_seqs=5)
    prompt = "{}Picture 1: <img></img>\n".format(question)
    return llm, prompt, None
180
181
# Qwen2-VL
def run_qwen2_vl(question):
    """Return ``(llm, prompt, stop_token_ids)`` for Qwen2-VL-7B-Instruct.

    The prompt is a ChatML-style conversation: a fixed system turn, a user
    turn carrying the image placeholder plus the question, and an open
    assistant turn.
    """
    llm = LLM(model="Qwen/Qwen2-VL-7B-Instruct", max_num_seqs=5)

    system_turn = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    user_turn = ("<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
                 f"{question}<|im_end|>\n")
    prompt = system_turn + user_turn + "<|im_start|>assistant\n"
    return llm, prompt, None
197
198
# Maps each ``--model-type`` CLI value to the function that builds the
# ``(llm, prompt, stop_token_ids)`` triple for that model family. The
# insertion order is also the order the CLI lists the choices in.
model_example_map = dict([
    ("llava", run_llava),
    ("llava-next", run_llava_next),
    ("llava-next-video", run_llava_next_video),
    ("fuyu", run_fuyu),
    ("phi3_v", run_phi3v),
    ("paligemma", run_paligemma),
    ("chameleon", run_chameleon),
    ("minicpmv", run_minicpmv),
    ("blip-2", run_blip2),
    ("internvl_chat", run_internvl),
    ("qwen_vl", run_qwen_vl),
    ("qwen2_vl", run_qwen2_vl),
])
213
214
def get_multi_modal_input(args):
    """Load the bundled demo media and a question to ask about it.

    Args:
        args: Parsed CLI namespace. ``args.modality`` selects the media;
            ``args.num_frames`` is read only for ``"video"``.

    Returns:
        ``{"data": ..., "question": ...}``, where ``data`` is the RGB-converted
        ``cherry_blossom`` image for ``"image"``, or
        ``VideoAsset(...).np_ndarrays`` for ``"video"``.

    Raises:
        ValueError: If ``args.modality`` is neither ``"image"`` nor ``"video"``.
    """
    modality = args.modality

    if modality == "image":
        rgb_image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
        return {
            "data": rgb_image,
            "question": "What is the content of this image?",
        }

    if modality == "video":
        frames = VideoAsset(name="sample_demo_1.mp4",
                            num_frames=args.num_frames).np_ndarrays
        return {"data": frames, "question": "Why is this video funny?"}

    raise ValueError(f"Modality {modality} is not supported.")
246
247
def main(args):
    """Run one offline vision-language generation demo and print the outputs.

    Args:
        args: Parsed CLI namespace. Reads ``num_prompts``, ``model_type`` and
            ``modality``; ``get_multi_modal_input`` also reads ``num_frames``.

    Raises:
        ValueError: If ``num_prompts`` is not positive, ``model_type`` has no
            example function, or ``modality`` is unsupported.
    """
    # Validate cheap CLI arguments before loading any media or model weights.
    # Use an explicit exception rather than ``assert``, which is stripped
    # under ``python -O``.
    if args.num_prompts <= 0:
        raise ValueError("--num-prompts must be a positive integer, "
                         f"got {args.num_prompts}.")

    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    data = mm_input["data"]
    question = mm_input["question"]

    llm, prompt, stop_token_ids = model_example_map[model](question)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)

    def make_request():
        # Build a fresh request dict (and inner dict) on every call, so that
        # batch entries are distinct objects.
        return {"prompt": prompt, "multi_modal_data": {modality: data}}

    if args.num_prompts == 1:
        inputs = make_request()  # single inference
    else:
        inputs = [make_request() for _ in range(args.num_prompts)]  # batch

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for output in outputs:
        print(output.outputs[0].text)
290
291
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="llava",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=4,
                        help='Number of prompts to run.')
    # Restrict to the modalities get_multi_modal_input() handles, so typos
    # are rejected by the parser with a usage message instead of failing
    # later inside main().
    parser.add_argument('--modality',
                        type=str,
                        default="image",
                        choices=["image", "video"],
                        help='Modality of the input.')
    # Only read when --modality is "video".
    parser.add_argument('--num-frames',
                        type=int,
                        default=16,
                        help='Number of frames to extract from the video.')
    args = parser.parse_args()
    main(args)