Offline Inference Vision Language Multi Image

Source: vllm-project/vllm.

  1"""
  2This example shows how to use vLLM for running offline inference with
  3multi-image input on vision language models, using the chat template defined
  4by the model.
  5"""
  6from argparse import Namespace
  7from typing import List
  8
  9from transformers import AutoProcessor, AutoTokenizer
 10
 11from vllm import LLM, SamplingParams
 12from vllm.multimodal.utils import fetch_image
 13from vllm.utils import FlexibleArgumentParser
 14
 15QUESTION = "What is the content of each image?"
 16IMAGE_URLS = [
 17    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
 18    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
 19]
 20
 21
 22def load_qwenvl_chat(question: str, image_urls: List[str]):
 23    model_name = "Qwen/Qwen-VL-Chat"
 24    llm = LLM(
 25        model=model_name,
 26        trust_remote_code=True,
 27        max_num_seqs=5,
 28        limit_mm_per_prompt={"image": len(image_urls)},
 29    )
 30    placeholders = "".join(f"Picture {i}: <img></img>\n"
 31                           for i, _ in enumerate(image_urls, start=1))
 32
 33    # This model does not have a chat_template attribute on its tokenizer,
 34    # so we need to explicitly pass it. We use ChatML since it's used in the
 35    # generation utils of the model:
 36    # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
 37    tokenizer = AutoTokenizer.from_pretrained(model_name,
 38                                              trust_remote_code=True)
 39
 40    # Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
 41    chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"  # noqa: E501
 42
 43    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
 44    prompt = tokenizer.apply_chat_template(messages,
 45                                           tokenize=False,
 46                                           add_generation_prompt=True,
 47                                           chat_template=chat_template)
 48
 49    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
 50    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
 51    return llm, prompt, stop_token_ids, None, chat_template
 52
 53
 54def load_phi3v(question: str, image_urls: List[str]):
 55    llm = LLM(
 56        model="microsoft/Phi-3.5-vision-instruct",
 57        trust_remote_code=True,
 58        max_model_len=4096,
 59        limit_mm_per_prompt={"image": len(image_urls)},
 60    )
 61    placeholders = "\n".join(f"<|image_{i}|>"
 62                             for i, _ in enumerate(image_urls, start=1))
 63    prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
 64    stop_token_ids = None
 65    return llm, prompt, stop_token_ids, None, None
 66
 67
 68def load_internvl(question: str, image_urls: List[str]):
 69    model_name = "OpenGVLab/InternVL2-2B"
 70
 71    llm = LLM(
 72        model=model_name,
 73        trust_remote_code=True,
 74        max_num_seqs=5,
 75        max_model_len=4096,
 76        limit_mm_per_prompt={"image": len(image_urls)},
 77    )
 78
 79    placeholders = "\n".join(f"Image-{i}: <image>\n"
 80                             for i, _ in enumerate(image_urls, start=1))
 81    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
 82
 83    tokenizer = AutoTokenizer.from_pretrained(model_name,
 84                                              trust_remote_code=True)
 85    prompt = tokenizer.apply_chat_template(messages,
 86                                           tokenize=False,
 87                                           add_generation_prompt=True)
 88
 89    # Stop tokens for InternVL
 90    # models variants may have different stop tokens
 91    # please refer to the model card for the correct "stop words":
 92    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
 93    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
 94    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
 95
 96    return llm, prompt, stop_token_ids, None, None
 97
 98
 99def load_qwen2_vl(question, image_urls: List[str]):
100    try:
101        from qwen_vl_utils import process_vision_info
102    except ModuleNotFoundError:
103        print('WARNING: `qwen-vl-utils` not installed, input images will not '
104              'be automatically resized. You can enable this functionality by '
105              '`pip install qwen-vl-utils`.')
106        process_vision_info = None
107
108    model_name = "Qwen/Qwen2-VL-7B-Instruct"
109
110    llm = LLM(
111        model=model_name,
112        max_num_seqs=5,
113        max_model_len=32768 if process_vision_info is None else 4096,
114        limit_mm_per_prompt={"image": len(image_urls)},
115    )
116
117    placeholders = [{"type": "image", "image": url} for url in image_urls]
118    messages = [{
119        "role": "system",
120        "content": "You are a helpful assistant."
121    }, {
122        "role":
123        "user",
124        "content": [
125            *placeholders,
126            {
127                "type": "text",
128                "text": question
129            },
130        ],
131    }]
132
133    processor = AutoProcessor.from_pretrained(model_name)
134
135    prompt = processor.apply_chat_template(messages,
136                                           tokenize=False,
137                                           add_generation_prompt=True)
138
139    stop_token_ids = None
140
141    if process_vision_info is None:
142        image_data = [fetch_image(url) for url in image_urls]
143    else:
144        image_data, _ = process_vision_info(messages)
145
146    return llm, prompt, stop_token_ids, image_data, None
147
148
# Registry of model types -> loader functions. Every loader is called as
# ``loader(question, image_urls)`` and returns the 5-tuple
# ``(llm, prompt, stop_token_ids, image_data, chat_template)``. The keys are
# also the accepted values of the ``--model-type`` CLI option.
model_example_map = dict(
    phi3_v=load_phi3v,
    internvl_chat=load_internvl,
    qwen2_vl=load_qwen2_vl,
    qwen_vl_chat=load_qwenvl_chat,
)
155
156
def run_generate(model: str, question: str, image_urls: List[str]):
    """Answer *question* about *image_urls* via ``LLM.generate`` and print it.

    Args:
        model: Key into ``model_example_map`` selecting the loader.
        question: Text question appended after the image placeholders.
        image_urls: URLs of the images to attach to the prompt.

    Uses the loader's already-formatted prompt, so the returned chat template
    is ignored. If the loader supplied no image data, every URL is fetched
    with ``fetch_image``. Sampling uses ``temperature=0.0`` and
    ``max_tokens=128``.
    """
    llm, prompt, stop_token_ids, image_data, _ = model_example_map[model](
        question, image_urls)
    if image_data is None:
        image_data = [fetch_image(url) for url in image_urls]

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=stop_token_ids)

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {
                "image": image_data
            },
        },
        sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
175
176    for o in outputs:
177        generated_text = o.outputs[0].text
178        print(generated_text)
179
180
181def run_chat(model: str, question: str, image_urls: List[str]):
182    llm, _, stop_token_ids, _, chat_template = model_example_map[model](
183        question, image_urls)
184
185    sampling_params = SamplingParams(temperature=0.0,
186                                     max_tokens=128,
187                                     stop_token_ids=stop_token_ids)
188    outputs = llm.chat(
189        [{
190            "role":
191            "user",
192            "content": [
193                {
194                    "type": "text",
195                    "text": question,
196                },
197                *({
198                    "type": "image_url",
199                    "image_url": {
200                        "url": image_url
201                    },
202                } for image_url in image_urls),
203            ],
204        }],
205        sampling_params=sampling_params,
206        chat_template=chat_template,
207    )
208
209    for o in outputs:
210        generated_text = o.outputs[0].text
211        print(generated_text)
212
213
def main(args: Namespace):
    """Run the demo selected by the parsed command-line arguments.

    ``args.model_type`` picks the loader from ``model_example_map`` and
    ``args.method`` picks the entry point (``"generate"`` or ``"chat"``).
    Both run on the module-level ``QUESTION`` and ``IMAGE_URLS``.

    Raises:
        ValueError: if ``args.method`` is neither ``"generate"`` nor
            ``"chat"``.
    """
    model = args.model_type
    method = args.method

    if method not in ("generate", "chat"):
        raise ValueError(f"Invalid method: {method}")

    runner = run_generate if method == "generate" else run_chat
    runner(model, QUESTION, IMAGE_URLS)
224
225
if __name__ == "__main__":
    # Script entry point: choose a model loader (--model-type) and which
    # vllm.LLM method to demo (--method), then hand off to main().
    parser = FlexibleArgumentParser(
        description=("Demo on using vLLM for offline inference with "
                     "vision language models that support multi-image input"))
    parser.add_argument("--model-type", "-m",
                        type=str,
                        default="phi3_v",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument("--method",
                        type=str,
                        default="generate",
                        choices=["generate", "chat"],
                        help="The method to run in `vllm.LLM`.")

    args = parser.parse_args()
    main(args)
240                        help="The method to run in `vllm.LLM`.")
241
242    args = parser.parse_args()
243    main(args)