Offline Inference Vision Language

Source vllm-project/vllm.

  1"""
  2This example shows how to use vLLM for running offline inference 
  3with the correct prompt format on vision language models.
  4
  5For most models, the prompt format should follow corresponding examples
  6on HuggingFace model repository.
  7"""
  8from transformers import AutoTokenizer
  9
 10from vllm import LLM, SamplingParams
 11from vllm.assets.image import ImageAsset
 12from vllm.utils import FlexibleArgumentParser
 13
 14# Input image and question
 15image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
 16question = "What is the content of this image?"
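
# Alternatively (not part of the original example), a local image file could be
# used instead of the bundled asset; the path below is only a placeholder:
#   from PIL import Image
#   image = Image.open("/path/to/your_image.jpg").convert("RGB")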


# LLaVA-1.5
def run_llava(question):

    prompt = f"USER: <image>\n{question}\nASSISTANT:"
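    # (Illustrative) With the default question above, this is the string
    # "USER: <image>\nWhat is the content of this image?\nASSISTANT:"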

    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(question):

    prompt = f"[INST] <image>\n{question} [/INST]"
    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Fuyu
def run_fuyu(question):

    prompt = f"{question}\n"
    llm = LLM(model="adept/fuyu-8b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Phi-3-Vision
def run_phi3v(question):

    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"  # noqa: E501
    # Note: The default setting of max_num_seqs (256) and
    # max_model_len (128k) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.

    # In this example, we override max_num_seqs to 5 while
    # keeping the original context length of 128k.
    llm = LLM(
        model="microsoft/Phi-3-vision-128k-instruct",
        trust_remote_code=True,
        max_num_seqs=5,
    )
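    # (Sketch, not from the original example) On smaller GPUs the context
    # length could also be capped, e.g. LLM(..., max_num_seqs=2,
    # max_model_len=4096); these values are illustrative placeholders.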
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# PaliGemma
def run_paligemma(question):

    # PaliGemma has a special prompt format for VQA
    prompt = "caption en"
    llm = LLM(model="google/paligemma-3b-mix-224")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Chameleon
def run_chameleon(question):

    prompt = f"{question}<image>"
    llm = LLM(model="facebook/chameleon-7b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# MiniCPM-V
def run_minicpmv(question):

    # 2.0
    # The official repo doesn't work yet, so we need to use a fork for now.
    # For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 # noqa
    # model_name = "HwwwH/MiniCPM-V-2"

    # 2.5
    # model_name = "openbmb/MiniCPM-Llama3-V-2_5"

    # 2.6
    model_name = "openbmb/MiniCPM-V-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
    )
    # NOTE: The stop_token_ids are different for various versions of MiniCPM-V.
    # 2.0
    # stop_token_ids = [tokenizer.eos_id]

    # 2.5
    # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]

    # 2.6
    stop_tokens = ['<|im_end|>', '<|endoftext|>']
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    messages = [{
        'role': 'user',
        'content': f'(<image>./</image>)\n{question}'
    }]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    return llm, prompt, stop_token_ids


# InternVL
def run_internvl(question):
    model_name = "OpenGVLab/InternVL2-2B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL
    # Model variants may have different stop tokens;
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return llm, prompt, stop_token_ids


# BLIP-2
def run_blip2(question):

    # The BLIP-2 prompt format shown on the HuggingFace model repository is
    # inaccurate. See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 # noqa
    prompt = f"Question: {question} Answer:"
    llm = LLM(model="Salesforce/blip2-opt-2.7b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


model_example_map = {
    "llava": run_llava,
    "llava-next": run_llava_next,
    "fuyu": run_fuyu,
    "phi3_v": run_phi3v,
    "paligemma": run_paligemma,
    "chameleon": run_chameleon,
    "minicpmv": run_minicpmv,
    "blip-2": run_blip2,
    "internvl_chat": run_internvl,
}
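# (Descriptive note, not in the original) Each entry maps the --model-type CLI
# value to a helper that returns (llm, prompt, stop_token_ids); supporting
# another model means writing a run_<model>() helper and registering it here.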


def main(args):
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    llm, prompt, stop_token_ids = model_example_map[model](question)

    # We set temperature to 0.2 so that outputs can differ even when all
    # prompts are identical during batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)

    assert args.num_prompts > 0
    if args.num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompt,
            "multi_modal_data": {
                "image": image
            },
        }

    else:
        # Batch inference
        inputs = [{
            "prompt": prompt,
            "multi_modal_data": {
                "image": image
            },
        } for _ in range(args.num_prompts)]

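    # llm.generate accepts either a single prompt dict or a list of them and
    # returns one RequestOutput per prompt, in the same order.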
    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="llava",
                        choices=model_example_map.keys(),
                        help='HuggingFace "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=1,
                        help='Number of prompts to run.')

    args = parser.parse_args()
    main(args)
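
Assuming the script is saved as offline_inference_vision_language.py (the file name is an assumption, matching this example's location in the vLLM repository), it can be run for a single prompt or for a small batch of identical prompts, for example:

python offline_inference_vision_language.py --model-type llava
python offline_inference_vision_language.py -m internvl_chat --num-prompts 4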