Offline Inference Vision Language Embedding

Offline Inference Vision Language Embedding

Source: vllm-project/vllm.

  1"""
  2This example shows how to use vLLM for running offline inference with
  3the correct prompt format on vision language models for multimodal embedding.
  4
  5For most models, the prompt format should follow corresponding examples
  6on HuggingFace model repository.
  7"""
  8from argparse import Namespace
  9from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args
 10
 11from PIL.Image import Image
 12
 13from vllm import LLM
 14from vllm.multimodal.utils import fetch_image
 15from vllm.utils import FlexibleArgumentParser
 16
 17
 18class TextQuery(TypedDict):
 19    modality: Literal["text"]
 20    text: str
 21
 22
 23class ImageQuery(TypedDict):
 24    modality: Literal["image"]
 25    image: Image
 26
 27
 28class TextImageQuery(TypedDict):
 29    modality: Literal["text+image"]
 30    text: str
 31    image: Image
 32
 33
 34QueryModality = Literal["text", "image", "text+image"]
 35Query = Union[TextQuery, ImageQuery, TextImageQuery]
 36
 37
 38class ModelRequestData(NamedTuple):
 39    llm: LLM
 40    prompt: str
 41    image: Optional[Image]
 42
 43
 44def run_e5_v(query: Query):
 45    llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n'  # noqa: E501
 46
 47    if query["modality"] == "text":
 48        text = query["text"]
 49        prompt = llama3_template.format(
 50            f"{text}\nSummary above sentence in one word: ")
 51        image = None
 52    elif query["modality"] == "image":
 53        prompt = llama3_template.format(
 54            "<image>\nSummary above image in one word: ")
 55        image = query["image"]
 56    else:
 57        modality = query['modality']
 58        raise ValueError(f"Unsupported query modality: '{modality}'")
 59
 60    llm = LLM(
 61        model="royokong/e5-v",
 62        task="embedding",
 63        max_model_len=4096,
 64    )
 65
 66    return ModelRequestData(
 67        llm=llm,
 68        prompt=prompt,
 69        image=image,
 70    )
 71
 72
 73def run_vlm2vec(query: Query):
 74    if query["modality"] == "text":
 75        text = query["text"]
 76        prompt = f"Find me an everyday image that matches the given caption: {text}"  # noqa: E501
 77        image = None
 78    elif query["modality"] == "image":
 79        prompt = "<|image_1|> Find a day-to-day image that looks similar to the provided image."  # noqa: E501
 80        image = query["image"]
 81    elif query["modality"] == "text+image":
 82        text = query["text"]
 83        prompt = f"<|image_1|> Represent the given image with the following question: {text}"  # noqa: E501
 84        image = query["image"]
 85    else:
 86        modality = query['modality']
 87        raise ValueError(f"Unsupported query modality: '{modality}'")
 88
 89    llm = LLM(
 90        model="TIGER-Lab/VLM2Vec-Full",
 91        task="embedding",
 92        trust_remote_code=True,
 93        mm_processor_kwargs={"num_crops": 4},
 94    )
 95
 96    return ModelRequestData(
 97        llm=llm,
 98        prompt=prompt,
 99        image=image,
100    )
101
102
def get_query(modality: QueryModality) -> Query:
    """Build the example query for the requested modality.

    The image-bearing queries download their picture with ``fetch_image``;
    the download only happens for the modality that was actually requested.

    Args:
        modality: One of ``"text"``, ``"image"`` or ``"text+image"``.

    Returns:
        A ``TextQuery``, ``ImageQuery`` or ``TextImageQuery`` matching
        ``modality``.

    Raises:
        ValueError: If ``modality`` is not one of the three known values.
    """
    dog_image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg"  # noqa: E501
    cat_image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/179px-Felis_catus-cat_on_snow.jpg"  # noqa: E501

    if modality == "text":
        query = TextQuery(modality="text", text="A dog sitting in the grass")
    elif modality == "image":
        query = ImageQuery(
            modality="image",
            image=fetch_image(dog_image_url),
        )
    elif modality == "text+image":
        query = TextImageQuery(
            modality="text+image",
            text="A cat standing in the snow.",
            image=fetch_image(cat_image_url),
        )
    else:
        msg = f"Modality {modality} is not supported."
        raise ValueError(msg)

    return query
126
127
def run_encode(model: str, modality: QueryModality) -> None:
    """Embed one example query with the chosen model and print the result.

    Args:
        model: Key into ``model_example_map`` selecting the model runner.
        modality: Which example query to build via ``get_query``.

    Raises:
        KeyError: If ``model`` is not a key of ``model_example_map``.
        ValueError: If ``modality`` is unknown to ``get_query`` or is not
            supported by the selected runner.
    """
    query = get_query(modality)
    req_data = model_example_map[model](query)

    # The image entry is only added when the request carries an image;
    # text-only requests send an empty multi-modal mapping.
    mm_data = {}
    if req_data.image is not None:
        mm_data["image"] = req_data.image

    outputs = req_data.llm.encode({
        "prompt": req_data.prompt,
        "multi_modal_data": mm_data,
    })

    for output in outputs:
        print(output.outputs.embedding)
143
144
def main(args: Namespace) -> None:
    """Run the embedding demo using the parsed command-line arguments.

    Args:
        args: Parsed arguments; ``model_name`` and ``modality`` are read.
    """
    run_encode(args.model_name, args.modality)
147
148
# Maps each `--model-name` choice to the function that prepares its request.
# The CLI takes its valid model names from these keys.
model_example_map = {
    "e5_v": run_e5_v,
    "vlm2vec": run_vlm2vec,
}
153
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'vision language models for multimodal embedding')
    # Valid model names are whatever runners are registered in the map.
    parser.add_argument('--model-name',
                        '-m',
                        type=str,
                        default="vlm2vec",
                        choices=model_example_map.keys(),
                        help='The name of the embedding model.')
    # Valid modalities come straight from the QueryModality literal.
    parser.add_argument('--modality',
                        type=str,
                        default="image",
                        choices=get_args(QueryModality),
                        help='Modality of the input.')
    args = parser.parse_args()
    main(args)