# Offline Inference Vision Language Embedding
# Source: vllm-project/vllm.
"""
This example shows how to use vLLM for running offline inference with
the correct prompt format on vision language models for multimodal embedding.

For most models, the prompt format should follow the corresponding examples
in the model's HuggingFace repository.
"""
8from argparse import Namespace
9from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args
10
11from PIL.Image import Image
12
13from vllm import LLM
14from vllm.multimodal.utils import fetch_image
15from vllm.utils import FlexibleArgumentParser
16
17
class TextQuery(TypedDict, total=True):
    """A text-only query to embed.

    ``modality`` is the tag the request builders branch on to pick the
    prompt format for plain text.
    """

    # Discriminator tag; always the literal string "text".
    modality: Literal["text"]
    # The text to be embedded.
    text: str
21
22
class ImageQuery(TypedDict, total=True):
    """An image-only query to embed.

    ``modality`` is the tag the request builders branch on to pick the
    prompt format for a bare image.
    """

    # Discriminator tag; always the literal string "image".
    modality: Literal["image"]
    # The PIL image to be embedded.
    image: Image
26
27
class TextImageQuery(TypedDict, total=True):
    """A combined query: an image paired with accompanying text.

    ``modality`` is the tag the request builders branch on to pick the
    prompt format that mixes both inputs.
    """

    # Discriminator tag; always the literal string "text+image".
    modality: Literal["text+image"]
    # Text that accompanies the image in the prompt.
    text: str
    # The PIL image to be embedded alongside the text.
    image: Image
32
33
# Every modality tag a query may carry. ``get_args(QueryModality)`` is also
# used as the allowed values of the ``--modality`` command-line option.
QueryModality = Literal["text", "image", "text+image"]
# Any supported query shape; callers distinguish them by the "modality" key.
Query = Union[TextQuery, ImageQuery, TextImageQuery]
36
37
class ModelRequestData(NamedTuple):
    """Everything ``run_encode`` needs to issue one embedding request."""

    # The vLLM engine configured for the selected embedding model.
    llm: LLM
    # Fully formatted prompt, including any model-specific image placeholder.
    prompt: str
    # Image to attach as multimodal data, or None for text-only prompts.
    image: Optional[Image]
42
43
def run_e5_v(query: Query) -> ModelRequestData:
    """Prepare an E5-V embedding request for ``query``.

    E5-V is prompted through a Llama-3 style chat template whose user turn
    asks the model to summarize the input in one word. Only "text" and
    "image" queries are supported by this builder.

    Args:
        query: The text-only or image-only query to embed.

    Returns:
        A ``ModelRequestData`` holding a newly constructed ``LLM``, the
        formatted prompt, and the query image (``None`` for text queries).

    Raises:
        ValueError: If the query's modality is neither "text" nor "image".
    """
    llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n'  # noqa: E501

    modality = query["modality"]
    if modality == "text":
        instruction = f"{query['text']}\nSummary above sentence in one word: "
        image = None
    elif modality == "image":
        instruction = "<image>\nSummary above image in one word: "
        image = query["image"]
    else:
        raise ValueError(f"Unsupported query modality: '{modality}'")

    # A fresh engine is built on every call; acceptable for a one-shot demo.
    llm = LLM(model="royokong/e5-v", task="embedding", max_model_len=4096)

    return ModelRequestData(
        llm=llm,
        prompt=llama3_template.format(instruction),
        image=image,
    )
71
72
def run_vlm2vec(query: Query) -> ModelRequestData:
    """Prepare a VLM2Vec embedding request for ``query``.

    All three query modalities are supported. Image-bearing prompts
    reference the attached image through the ``<|image_1|>`` placeholder.

    Args:
        query: The text, image, or text+image query to embed.

    Returns:
        A ``ModelRequestData`` holding a newly constructed ``LLM``, the
        formatted prompt, and the query image (``None`` for text queries).

    Raises:
        ValueError: If the query's modality is not one of the supported tags.
    """
    modality = query["modality"]
    image = None

    if modality == "text":
        prompt = f"Find me an everyday image that matches the given caption: {query['text']}"  # noqa: E501
    elif modality == "image":
        prompt = "<|image_1|> Find a day-to-day image that looks similar to the provided image."  # noqa: E501
        image = query["image"]
    elif modality == "text+image":
        prompt = f"<|image_1|> Represent the given image with the following question: {query['text']}"  # noqa: E501
        image = query["image"]
    else:
        raise ValueError(f"Unsupported query modality: '{modality}'")

    llm = LLM(
        model="TIGER-Lab/VLM2Vec-Full",
        task="embedding",
        trust_remote_code=True,
        # Forwarded to the model's multimodal processor; presumably controls
        # how many crops each image is split into - confirm on the model card.
        mm_processor_kwargs={"num_crops": 4},
    )

    return ModelRequestData(llm=llm, prompt=prompt, image=image)
101
102
def get_query(modality: QueryModality) -> Query:
    """Return a built-in sample query for ``modality``.

    Queries that carry an image obtain it from a remote https URL via
    ``fetch_image`` at call time.

    Args:
        modality: Which kind of sample query to build.

    Returns:
        A ``TextQuery``, ``ImageQuery`` or ``TextImageQuery`` instance.

    Raises:
        ValueError: If ``modality`` is not one of the supported tags.
    """
    if modality == "text":
        return TextQuery(modality="text", text="A dog sitting in the grass")

    if modality == "image":
        dog_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg"  # noqa: E501
        return ImageQuery(modality="image", image=fetch_image(dog_url))

    if modality == "text+image":
        cat_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/179px-Felis_catus-cat_on_snow.jpg"  # noqa: E501
        return TextImageQuery(
            modality="text+image",
            text="A cat standing in the snow.",
            image=fetch_image(cat_url),
        )

    raise ValueError(f"Modality {modality} is not supported.")
126
127
def run_encode(model: str, modality: QueryModality) -> None:
    """Embed a sample query with the chosen model and print each embedding.

    Args:
        model: Key into ``model_example_map`` selecting the request builder.
        modality: Which sample query (see ``get_query``) to embed.
    """
    # The sample query is built before the model lookup, so an unknown model
    # name only fails after the query (and any image) has been obtained.
    query = get_query(modality)
    req_data = model_example_map[model](query)

    # Text-only requests still send a multimodal payload, just an empty one.
    mm_data = {} if req_data.image is None else {"image": req_data.image}

    request = {
        "prompt": req_data.prompt,
        "multi_modal_data": mm_data,
    }
    outputs = req_data.llm.encode(request)

    for output in outputs:
        print(output.outputs.embedding)
143
144
def main(args: Namespace) -> None:
    """Run the embedding demo for the model and modality parsed from the CLI."""
    run_encode(model=args.model_name, modality=args.modality)
147
148
# Maps each ``--model-name`` choice to the function that builds its request.
model_example_map = dict(
    e5_v=run_e5_v,
    vlm2vec=run_vlm2vec,
)
153
if __name__ == "__main__":
    # Command-line entry point: choose an embedding model and query modality.
    parser = FlexibleArgumentParser(
        description=("Demo on using vLLM for offline inference with "
                     "vision language models for multimodal embedding"))
    parser.add_argument(
        "--model-name",
        "-m",
        type=str,
        choices=model_example_map.keys(),
        default="vlm2vec",
        help="The name of the embedding model.",
    )
    parser.add_argument(
        "--modality",
        type=str,
        choices=get_args(QueryModality),
        default="image",
        help="Modality of the input.",
    )
    args = parser.parse_args()
    main(args)