Offline Inference Vision Language

Source vllm-project/vllm.

  1"""
  2This example shows how to use vLLM for running offline inference 
  3with the correct prompt format on vision language models.
  4
  5For most models, the prompt format should follow corresponding examples
  6on HuggingFace model repository.
  7"""
  8from transformers import AutoTokenizer
  9
 10from vllm import LLM, SamplingParams
 11from vllm.assets.image import ImageAsset
 12from vllm.utils import FlexibleArgumentParser
 13
 14# Input image and question
 15image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
 16question = "What is the content of this image?"
 17
 18
 19# LLaVA-1.5
 20def run_llava(question):
 21
 22    prompt = f"USER: <image>\n{question}\nASSISTANT:"
 23
 24    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
 25
 26    return llm, prompt
 27
 28
 29# LLaVA-1.6/LLaVA-NeXT
 30def run_llava_next(question):
 31
 32    prompt = f"[INST] <image>\n{question} [/INST]"
 33    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf")
 34
 35    return llm, prompt
 36
 37
 38# Fuyu
 39def run_fuyu(question):
 40
 41    prompt = f"{question}\n"
 42    llm = LLM(model="adept/fuyu-8b")
 43
 44    return llm, prompt
 45
 46
 47# Phi-3-Vision
 48def run_phi3v(question):
 49
 50    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"  # noqa: E501
 51    # Note: The default setting of max_num_seqs (256) and
 52    # max_model_len (128k) for this model may cause OOM.
 53    # You may lower either to run this example on lower-end GPUs.
 54
 55    # In this example, we override max_num_seqs to 5 while
 56    # keeping the original context length of 128k.
 57    llm = LLM(
 58        model="microsoft/Phi-3-vision-128k-instruct",
 59        trust_remote_code=True,
 60        max_num_seqs=5,
 61    )
 62    return llm, prompt
 63
 64
 65# PaliGemma
 66def run_paligemma(question):
 67
 68    # PaliGemma has special prompt format for VQA
 69    prompt = "caption en"
 70    llm = LLM(model="google/paligemma-3b-mix-224")
 71
 72    return llm, prompt
 73
 74
 75# Chameleon
 76def run_chameleon(question):
 77
 78    prompt = f"{question}<image>"
 79    llm = LLM(model="facebook/chameleon-7b")
 80    return llm, prompt
 81
 82
 83# MiniCPM-V
 84def run_minicpmv(question):
 85
 86    # 2.0
 87    # The official repo doesn't work yet, so we need to use a fork for now
 88    # For more details, please see: See: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 # noqa
 89    # model_name = "HwwwH/MiniCPM-V-2"
 90
 91    # 2.5
 92    model_name = "openbmb/MiniCPM-Llama3-V-2_5"
 93    tokenizer = AutoTokenizer.from_pretrained(model_name,
 94                                              trust_remote_code=True)
 95    llm = LLM(
 96        model=model_name,
 97        trust_remote_code=True,
 98    )
 99
100    messages = [{
101        'role': 'user',
102        'content': f'(<image>./</image>)\n{question}'
103    }]
104    prompt = tokenizer.apply_chat_template(messages,
105                                           tokenize=False,
106                                           add_generation_prompt=True)
107    return llm, prompt
108
109
110# InternVL
def run_internvl(question):
    """Create the InternVL engine and a ChatML-style prompt.

    Args:
        question: Text inserted into the user turn, after the image
            placeholder.

    Returns:
        A ``(llm, prompt)`` tuple.
    """
    # Generally, InternVL can use the chatml template for conversation.
    chat_template = "<|im_start|>User\n{prompt}<|im_end|>\n<|im_start|>Assistant\n"
    user_turn = f"<image>\n{question}\n"
    text = chat_template.format(prompt=user_turn)

    engine = LLM(
        model="OpenGVLab/InternVL2-4B",
        trust_remote_code=True,
        max_num_seqs=5,
    )
    return engine, text
122
123
124# BLIP-2
def run_blip2(question):
    """Create the BLIP-2 engine and the matching prompt.

    Args:
        question: Text inserted between ``Question:`` and ``Answer:``.

    Returns:
        A ``(llm, prompt)`` tuple.
    """
    engine = LLM(model="Salesforce/blip2-opt-2.7b")

    # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
    # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
    text = f"Question: {question} Answer:"
    return engine, text
132
133
# Maps each ``--model-type`` value accepted on the command line to the
# function that builds the engine and prompt for that model. Every function
# takes the question and returns an ``(llm, prompt)`` tuple.
model_example_map = {
    "llava": run_llava,
    "llava-next": run_llava_next,
    "fuyu": run_fuyu,
    "phi3_v": run_phi3v,
    "paligemma": run_paligemma,
    "chameleon": run_chameleon,
    "minicpmv": run_minicpmv,
    "blip-2": run_blip2,
    "internvl_chat": run_internvl,
}
145
146
def main(args):
    """Run the selected vision-language example and print its generations.

    Args:
        args: Parsed command-line namespace providing ``model_type`` (a key
            of ``model_example_map``) and ``num_prompts`` (a positive int).

    Raises:
        ValueError: If ``num_prompts`` is not positive, or if ``model_type``
            has no entry in ``model_example_map``.
    """
    # Validate the cheap arguments first so that a bad value fails before
    # the LLM is constructed. A real exception is raised instead of using
    # ``assert`` because assertions are stripped under ``python -O``.
    num_prompts = args.num_prompts
    if num_prompts <= 0:
        raise ValueError(
            f"--num-prompts must be a positive integer, got {num_prompts}.")

    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    llm, prompt = model_example_map[model](question)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(temperature=0.2, max_tokens=64)

    def build_request():
        # One request pairs the text prompt with the shared input image.
        return {
            "prompt": prompt,
            "multi_modal_data": {
                "image": image
            },
        }

    if num_prompts == 1:
        # Single inference: pass one request dict directly.
        inputs = build_request()
    else:
        # Batch inference: pass a list with one request per prompt.
        inputs = [build_request() for _ in range(num_prompts)]

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
182
183
if __name__ == "__main__":
    # Command-line entry point: pick the model example with --model-type/-m
    # and the number of identical prompts to send with --num-prompts.
    description = ('Demo on using vLLM for offline inference with '
                   'vision language models')
    parser = FlexibleArgumentParser(description=description)

    parser.add_argument(
        '--model-type',
        '-m',
        type=str,
        default="llava",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        '--num-prompts',
        type=int,
        default=1,
        help='Number of prompts to run.',
    )

    args = parser.parse_args()
    main(args)