Offline Inference Vision Language Multi Image

Source: vllm-project/vllm.

  1"""
  2This example shows how to use vLLM for running offline inference with
  3multi-image input on vision language models, using the chat template defined
  4by the model.
  5"""
  6from argparse import Namespace
  7from typing import List
  8
  9from transformers import AutoProcessor, AutoTokenizer
 10
 11from vllm import LLM, SamplingParams
 12from vllm.multimodal.utils import fetch_image
 13from vllm.utils import FlexibleArgumentParser
 14
 15QUESTION = "What is the content of each image?"
 16IMAGE_URLS = [
 17    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
 18    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
 19]
 20
 21
 22def load_phi3v(question, image_urls: List[str]):
 23    llm = LLM(
 24        model="microsoft/Phi-3.5-vision-instruct",
 25        trust_remote_code=True,
 26        max_model_len=4096,
 27        limit_mm_per_prompt={"image": len(image_urls)},
 28    )
 29    placeholders = "\n".join(f"<|image_{i}|>"
 30                             for i, _ in enumerate(image_urls, start=1))
 31    prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
 32    stop_token_ids = None
 33    return llm, prompt, stop_token_ids, None
 34
 35
 36def load_internvl(question, image_urls: List[str]):
 37    model_name = "OpenGVLab/InternVL2-2B"
 38
 39    llm = LLM(
 40        model=model_name,
 41        trust_remote_code=True,
 42        max_num_seqs=5,
 43        max_model_len=4096,
 44        limit_mm_per_prompt={"image": len(image_urls)},
 45    )
 46
 47    placeholders = "\n".join(f"Image-{i}: <image>\n"
 48                             for i, _ in enumerate(image_urls, start=1))
 49    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
 50
 51    tokenizer = AutoTokenizer.from_pretrained(model_name,
 52                                              trust_remote_code=True)
 53    prompt = tokenizer.apply_chat_template(messages,
 54                                           tokenize=False,
 55                                           add_generation_prompt=True)
 56
 57    # Stop tokens for InternVL
 58    # models variants may have different stop tokens
 59    # please refer to the model card for the correct "stop words":
 60    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
 61    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
 62    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
 63
 64    return llm, prompt, stop_token_ids, None
 65
 66
 67def load_qwen2_vl(question, image_urls: List[str]):
 68    try:
 69        from qwen_vl_utils import process_vision_info
 70    except ModuleNotFoundError:
 71        print('WARNING: `qwen-vl-utils` not installed, input images will not '
 72              'be automatically resized. You can enable this functionality by '
 73              '`pip install qwen-vl-utils`.')
 74        process_vision_info = None
 75
 76    model_name = "Qwen/Qwen2-VL-7B-Instruct"
 77
 78    llm = LLM(
 79        model=model_name,
 80        max_num_seqs=5,
 81        max_model_len=32768 if process_vision_info is None else 4096,
 82        limit_mm_per_prompt={"image": len(image_urls)},
 83    )
 84
 85    placeholders = [{"type": "image", "image": url} for url in image_urls]
 86    messages = [{
 87        "role": "system",
 88        "content": "You are a helpful assistant."
 89    }, {
 90        "role":
 91        "user",
 92        "content": [
 93            *placeholders,
 94            {
 95                "type": "text",
 96                "text": question
 97            },
 98        ],
 99    }]
100
101    processor = AutoProcessor.from_pretrained(model_name)
102
103    prompt = processor.apply_chat_template(messages,
104                                           tokenize=False,
105                                           add_generation_prompt=True)
106
107    stop_token_ids = None
108
109    if process_vision_info is None:
110        image_data = [fetch_image(url) for url in image_urls]
111    else:
112        image_data, _ = process_vision_info(messages)
113
114    return llm, prompt, stop_token_ids, image_data
115
116
# Registry of supported models: maps each `--model-type` choice to the loader
# function that is called with (question, image_urls) to set that model up.
model_example_map = {
    "phi3_v": load_phi3v,
    "internvl_chat": load_internvl,
    "qwen2_vl": load_qwen2_vl,
}
122
123
def run_generate(model, question: str, image_urls: List[str]):
    """Answer ``question`` about the images via ``llm.generate`` and print it.

    Args:
        model: Key into ``model_example_map`` selecting the loader to use.
        question: Question forwarded to the loader.
        image_urls: URLs of the images to send with the prompt.
    """
    loader = model_example_map[model]
    llm, prompt, stop_token_ids, image_data = loader(question, image_urls)

    # A loader that returns no image data leaves the download to this caller.
    if image_data is None:
        image_data = [fetch_image(url) for url in image_urls]

    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=128,
        stop_token_ids=stop_token_ids,
    )

    request = {
        "prompt": prompt,
        "multi_modal_data": {"image": image_data},
    }
    outputs = llm.generate(request, sampling_params=sampling_params)

    # Print the first completion of every returned output.
    for result in outputs:
        print(result.outputs[0].text)
146
147
def run_chat(model: str, question: str, image_urls: List[str]):
    """Answer ``question`` about the images via ``llm.chat`` and print it.

    Only the engine and stop token ids from the loader are used; the prompt
    and image data it returns are discarded, and the image URLs are passed
    inside the chat message instead.

    Args:
        model: Key into ``model_example_map`` selecting the loader to use.
        question: Text part of the user message.
        image_urls: URLs sent as ``image_url`` parts of the user message.
    """
    loader = model_example_map[model]
    llm, _, stop_token_ids, _ = loader(question, image_urls)

    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=128,
        stop_token_ids=stop_token_ids,
    )

    # Single user message: the text part first, then one part per image URL.
    text_part = {"type": "text", "text": question}
    image_parts = [{
        "type": "image_url",
        "image_url": {"url": url},
    } for url in image_urls]
    conversation = [{"role": "user", "content": [text_part, *image_parts]}]

    outputs = llm.chat(conversation, sampling_params=sampling_params)

    # Print the first completion of every returned output.
    for result in outputs:
        print(result.outputs[0].text)
176
177
def main(args: Namespace):
    """Run the example selected by the parsed command-line arguments.

    Args:
        args: Namespace providing ``model_type`` and ``method``.

    Raises:
        ValueError: If ``args.method`` is neither "generate" nor "chat".
    """
    model, method = args.model_type, args.method

    # Reject unknown methods up front, then pick the matching runner.
    if method not in ("generate", "chat"):
        raise ValueError(f"Invalid method: {method}")

    runner = run_generate if method == "generate" else run_chat
    runner(model, QUESTION, IMAGE_URLS)
188
189
if __name__ == "__main__":
    # Command-line entry point: parse the two options and hand them to main().
    cli = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models that support multi-image input')

    # Model choice is limited to the keys registered in model_example_map.
    cli.add_argument(
        '--model-type',
        '-m',
        type=str,
        default="phi3_v",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    cli.add_argument(
        "--method",
        type=str,
        default="generate",
        choices=["generate", "chat"],
        help="The method to run in `vllm.LLM`.",
    )

    args = cli.parse_args()
    main(args)
204                        help="The method to run in `vllm.LLM`.")
205
206    args = parser.parse_args()
207    main(args)