Offline Inference Vision Language Multi Image

Source: vllm-project/vllm.

"""
This example shows how to use vLLM for running offline inference with
multi-image input on vision language models, using the chat template defined
by the model.
"""
from argparse import Namespace
from typing import List, NamedTuple, Optional

from PIL.Image import Image
from transformers import AutoProcessor, AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser

QUESTION = "What is the content of each image?"
IMAGE_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
]


class ModelRequestData(NamedTuple):
    llm: LLM
    prompt: str
    stop_token_ids: Optional[List[int]]
    image_data: List[Image]
    chat_template: Optional[str]


def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "Qwen/Qwen-VL-Chat"
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": len(image_urls)},
    )
    placeholders = "".join(f"Picture {i}: <img></img>\n"
                           for i, _ in enumerate(image_urls, start=1))

    # This model does not have a chat_template attribute on its tokenizer,
    # so we need to explicitly pass it. We use ChatML since it's used in the
    # generation utils of the model:
    # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)

    # Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
    chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"  # noqa: E501

    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True,
                                           chat_template=chat_template)
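    # For reference: with the ChatML template above and two image URLs, the
    # rendered prompt should look roughly like (shown here for illustration):
    #   <|im_start|>user
    #   Picture 1: <img></img>
    #   Picture 2: <img></img>
    #
    #   What is the content of each image?<|im_end|>
    #   <|im_start|>assistant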

    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=chat_template,
    )


def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
    # num_crops is an override kwarg to the multimodal image processor;
    # for some models, e.g., Phi-3.5-vision-instruct, it is recommended
    # to use 16 for single-frame scenarios and 4 for multi-frame.
    #
    # Generally speaking, a larger value for num_crops results in more
    # tokens per image instance, because it may scale the image more in
    # the image preprocessing. References in the model docs and the
    # formula for the number of image tokens after the preprocessing
    # transform can be found below.
    #
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"num_crops": 4},
    )
    placeholders = "\n".join(f"<|image_{i}|>"
                             for i, _ in enumerate(image_urls, start=1))
    prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
    stop_token_ids = None

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "OpenGVLab/InternVL2-2B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL.
    # Model variants may have different stop tokens;
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_qwen2_vl(question: str, image_urls: List[str]) -> ModelRequestData:
    try:
        from qwen_vl_utils import process_vision_info
    except ModuleNotFoundError:
        print('WARNING: `qwen-vl-utils` not installed, input images will not '
              'be automatically resized. You can enable this functionality by '
              'running `pip install qwen-vl-utils`.')
        process_vision_info = None

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    llm = LLM(
        model=model_name,
        max_num_seqs=5,
        max_model_len=32768 if process_vision_info is None else 4096,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [{
        "role": "system",
        "content": "You are a helpful assistant."
    }, {
        "role":
        "user",
        "content": [
            *placeholders,
            {
                "type": "text",
                "text": question
            },
        ],
    }]

    processor = AutoProcessor.from_pretrained(model_name)

    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    stop_token_ids = None

    if process_vision_info is None:
        image_data = [fetch_image(url) for url in image_urls]
    else:
        image_data, _ = process_vision_info(messages)

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=image_data,
        chat_template=None,
    )


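# New models can be plugged into this example by writing a loader that builds
# the prompt and image inputs, returns a ModelRequestData, and is registered
# in the map below under its HuggingFace "model_type".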
model_example_map = {
    "phi3_v": load_phi3v,
    "internvl_chat": load_internvl,
    "qwen2_vl": load_qwen2_vl,
    "qwen_vl_chat": load_qwenvl_chat,
}


def run_generate(model: str, question: str, image_urls: List[str]):
    req_data = model_example_map[model](question, image_urls)

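    # Greedy decoding (temperature=0.0), capped at 128 new tokens, stopping on
    # the model-specific stop token ids returned by the loader (may be None).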
    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=req_data.stop_token_ids)

    outputs = req_data.llm.generate(
        {
            "prompt": req_data.prompt,
            "multi_modal_data": {
                "image": req_data.image_data
            },
        },
        sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def run_chat(model: str, question: str, image_urls: List[str]):
    req_data = model_example_map[model](question, image_urls)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=req_data.stop_token_ids)
    outputs = req_data.llm.chat(
        [{
            "role":
            "user",
            "content": [
                {
                    "type": "text",
                    "text": question,
                },
                *({
                    "type": "image_url",
                    "image_url": {
                        "url": image_url
                    },
                } for image_url in image_urls),
            ],
        }],
        sampling_params=sampling_params,
        chat_template=req_data.chat_template,
    )

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def main(args: Namespace):
    model = args.model_type
    method = args.method

    if method == "generate":
        run_generate(model, QUESTION, IMAGE_URLS)
    elif method == "chat":
        run_chat(model, QUESTION, IMAGE_URLS)
    else:
        raise ValueError(f"Invalid method: {method}")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models that support multi-image input')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="phi3_v",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument("--method",
                        type=str,
                        default="generate",
                        choices=["generate", "chat"],
                        help="The method to run in `vllm.LLM`.")

    args = parser.parse_args()
    main(args)
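To try the example end to end, pick one of the registered model types and a method. Assuming the script is saved as examples/offline_inference_vision_language_multi_image.py in a vllm-project/vllm checkout (the exact path is an assumption, not stated above), an invocation could look like:

python examples/offline_inference_vision_language_multi_image.py --model-type qwen2_vl --method chat

With --method generate, the pre-built prompt and the fetched images are passed to LLM.generate; with --method chat, LLM.chat receives the image URLs directly and applies the loader's chat template where one is provided.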