# Source code for vllm.multimodal.registry
import functools
from typing import Dict, Optional, Sequence
import torch
from vllm.config import ModelConfig
from vllm.logger import init_logger
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
MultiModalPlugin, MultiModalTokensCalc)
from .image import ImagePlugin
# Module-level logger, created via vLLM's `init_logger` with this module's name.
logger = init_logger(__name__)
class MultiModalRegistry:
    """
    A registry to dispatch data processing
    according to its modality and the target model.

    The registry handles both external and internal data input.

    Internally the registry keeps one :class:`MultiModalPlugin` per data
    type key (the value returned by the plugin's ``get_data_key()``) and
    forwards every registration / mapping request to the matching plugin.
    """

    # NOTE: these plugin instances are created once, when the class body is
    # executed, and are therefore shared by every registry that is
    # constructed without an explicit ``plugins`` argument.
    DEFAULT_PLUGINS = (ImagePlugin(), )

    def __init__(
            self,
            *,
            plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
        """
        Build the registry from an initial collection of plugins.

        Args:
            plugins: Plugins to install, keyword-only. Each one is stored
                under the key returned by its ``get_data_key()``; if two
                plugins report the same key, the later one in the sequence
                replaces the earlier one.
        """
        self._plugins = {p.get_data_key(): p for p in plugins}

    def register_plugin(self, plugin: MultiModalPlugin) -> None:
        """
        Install ``plugin`` under the key returned by its ``get_data_key()``.

        If a plugin is already stored under that key, a warning is logged
        and the existing plugin is replaced.
        """
        data_type_key = plugin.get_data_key()

        if data_type_key in self._plugins:
            logger.warning(
                "A plugin is already registered for data type %s, "
                "and will be overwritten by the new plugin %s.", data_type_key,
                plugin)

        self._plugins[data_type_key] = plugin

    def _get_plugin(self, data_type_key: str) -> MultiModalPlugin:
        """
        Return the plugin stored under ``data_type_key``.

        Raises:
            NotImplementedError: If no plugin is stored under that key.
        """
        plugin = self._plugins.get(data_type_key)
        if plugin is not None:
            return plugin

        msg = f"Unknown multi-modal data type: {data_type_key}"
        raise NotImplementedError(msg)

    def register_input_mapper(
        self,
        data_type_key: str,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper for a specific modality to a model class.

        The call is forwarded to the plugin stored under ``data_type_key``
        and that plugin's return value is passed back unchanged.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.

        Raises:
            NotImplementedError: If ``data_type_key`` has no plugin.
        """
        return self._get_plugin(data_type_key).register_input_mapper(mapper)

    def register_image_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper for image data to a model class.

        Shorthand for :meth:`register_input_mapper` with the ``"image"``
        data type key.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_input_mapper("image", mapper)

    def map_input(self, model_config: ModelConfig,
                  data: MultiModalDataDict) -> MultiModalInputs:
        """
        Apply an input mapper to the data passed to the model.

        Every ``(data_key, data_value)`` pair in ``data`` is handed to the
        plugin stored under ``data_key``; the dictionaries returned by the
        plugins are merged into a single :class:`MultiModalInputs`.

        See :meth:`MultiModalPlugin.map_input` for more details.

        Raises:
            NotImplementedError: If a key of ``data`` has no plugin.
            ValueError: If two plugins produce the same output key.
        """
        merged_dict: Dict[str, torch.Tensor] = {}

        for data_key, data_value in data.items():
            input_dict = self._get_plugin(data_key) \
                .map_input(model_config, data_value)

            for input_key, input_tensor in input_dict.items():
                # Refuse to silently overwrite an entry that an earlier
                # plugin already produced under the same name.
                if input_key in merged_dict:
                    raise ValueError(f"The input mappers (keys={set(data)}) "
                                     f"resulted in a conflicting keyword "
                                     f"argument to `forward()`: {input_key}")

                merged_dict[input_key] = input_tensor

        return MultiModalInputs(merged_dict)

    def create_input_mapper(self, model_config: ModelConfig):
        """
        Create an input mapper (see :meth:`map_input`) for a specific model.

        Returns:
            A ``functools.partial`` of :meth:`map_input` with
            ``model_config`` already bound, so that only the multi-modal
            data remains to be supplied.
        """
        return functools.partial(self.map_input, model_config)

    def register_max_multimodal_tokens(
        self,
        data_type_key: str,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of tokens, belonging to a
        specific modality, input to the language model for a model class.

        The call is forwarded to the plugin stored under ``data_type_key``
        and that plugin's return value is passed back unchanged.

        Raises:
            NotImplementedError: If ``data_type_key`` has no plugin.
        """
        return self._get_plugin(data_type_key) \
            .register_max_multimodal_tokens(max_mm_tokens)

    def register_max_image_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of image tokens
        input to the language model for a model class.

        Shorthand for :meth:`register_max_multimodal_tokens` with the
        ``"image"`` data type key.
        """
        return self.register_max_multimodal_tokens("image", max_mm_tokens)

    def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        The result is the sum of ``get_max_multimodal_tokens(model_config)``
        over every plugin currently stored in the registry.

        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
        """
        return sum(
            plugin.get_max_multimodal_tokens(model_config)
            for plugin in self._plugins.values())