Offline Inference Vision Language Multi Image
Source: vllm-project/vllm.
1"""
2This example shows how to use vLLM for running offline inference with
3multi-image input on vision language models, using the chat template defined
4by the model.
5"""
from argparse import Namespace
from typing import List, NamedTuple, Optional

from PIL.Image import Image
from transformers import AutoProcessor, AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser

QUESTION = "What is the content of each image?"
IMAGE_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
]


class ModelRequestData(NamedTuple):
    llm: LLM
    prompt: str
    stop_token_ids: Optional[List[int]]
    image_data: List[Image]
    chat_template: Optional[str]


def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "Qwen/Qwen-VL-Chat"
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": len(image_urls)},
    )
    placeholders = "".join(f"Picture {i}: <img></img>\n"
                           for i, _ in enumerate(image_urls, start=1))

    # This model does not have a chat_template attribute on its tokenizer,
    # so we need to explicitly pass it. We use ChatML since it's used in the
    # generation utils of the model:
    # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)

    # Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
    chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"  # noqa: E501

    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True,
                                           chat_template=chat_template)

    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=chat_template,
    )


def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
    # num_crops is an override kwarg to the multimodal image processor;
    # for some models, e.g., Phi-3.5-vision-instruct, it is recommended
    # to use 16 for single-frame scenarios and 4 for multi-frame.
    #
    # Generally speaking, a larger value for num_crops results in more
    # tokens per image instance, because the image may be scaled up more
    # during preprocessing. References in the model docs and the formula
    # for the number of image tokens after the preprocessing transform
    # can be found below.
    #
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"num_crops": 4},
    )
    placeholders = "\n".join(f"<|image_{i}|>"
                             for i, _ in enumerate(image_urls, start=1))
    prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
    stop_token_ids = None

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
    model_name = "OpenGVLab/InternVL2-2B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL.
    # Model variants may have different stop tokens;
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_qwen2_vl(question: str, image_urls: List[str]) -> ModelRequestData:
    try:
        from qwen_vl_utils import process_vision_info
    except ModuleNotFoundError:
        print('WARNING: `qwen-vl-utils` not installed, input images will not '
              'be automatically resized. You can enable this functionality by '
              'running `pip install qwen-vl-utils`.')
        process_vision_info = None

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    llm = LLM(
        model=model_name,
        max_num_seqs=5,
        # Without qwen-vl-utils, the images are not pre-resized, so a larger
        # context window is needed to fit the extra image tokens.
        max_model_len=32768 if process_vision_info is None else 4096,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [{
        "role": "system",
        "content": "You are a helpful assistant."
    }, {
        "role": "user",
        "content": [
            *placeholders,
            {
                "type": "text",
                "text": question
            },
        ],
    }]

    processor = AutoProcessor.from_pretrained(model_name)

    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    stop_token_ids = None

    if process_vision_info is None:
        image_data = [fetch_image(url) for url in image_urls]
    else:
        image_data, _ = process_vision_info(messages)

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=image_data,
        chat_template=None,
    )


model_example_map = {
    "phi3_v": load_phi3v,
    "internvl_chat": load_internvl,
    "qwen2_vl": load_qwen2_vl,
    "qwen_vl_chat": load_qwenvl_chat,
}


def run_generate(model: str, question: str, image_urls: List[str]):
    req_data = model_example_map[model](question, image_urls)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=req_data.stop_token_ids)

    outputs = req_data.llm.generate(
        {
            "prompt": req_data.prompt,
            "multi_modal_data": {
                "image": req_data.image_data
            },
        },
        sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def run_chat(model: str, question: str, image_urls: List[str]):
    req_data = model_example_map[model](question, image_urls)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=req_data.stop_token_ids)
    outputs = req_data.llm.chat(
        [{
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": question,
                },
                *({
                    "type": "image_url",
                    "image_url": {
                        "url": image_url
                    },
                } for image_url in image_urls),
            ],
        }],
        sampling_params=sampling_params,
        chat_template=req_data.chat_template,
    )

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def main(args: Namespace):
    model = args.model_type
    method = args.method

    if method == "generate":
        run_generate(model, QUESTION, IMAGE_URLS)
    elif method == "chat":
        run_chat(model, QUESTION, IMAGE_URLS)
    else:
        raise ValueError(f"Invalid method: {method}")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models that support multi-image input')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="phi3_v",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument("--method",
                        type=str,
                        default="generate",
                        choices=["generate", "chat"],
                        help="The method to run in `vllm.LLM`.")

    args = parser.parse_args()
    main(args)