import functools
from collections import UserDict
from typing import Dict, Mapping, Optional, Sequence
from vllm.config import ModelConfig
from vllm.logger import init_logger
from .audio import AudioPlugin
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
MultiModalPlugin, MultiModalTokensCalc, NestedTensors)
from .image import ImagePlugin
from .video import VideoPlugin
logger = init_logger(__name__)
class _MultiModalLimits(UserDict):
"""
Wraps `_limits_by_model` for a more informative error message
when attempting to access a model that does not exist.
"""
def __getitem__(self, key: ModelConfig) -> Dict[str, int]:
try:
return super().__getitem__(key)
except KeyError as exc:
msg = (f"Cannot find `mm_limits` for model={key.model}. Did you "
"forget to call `init_mm_limits_per_prompt`?")
raise KeyError(msg) from exc
[docs]class MultiModalRegistry:
"""
A registry that dispatches data processing to the
:class:`~vllm.multimodal.MultiModalPlugin` for each modality.
"""
DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin())
def __init__(
self,
*,
plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
self._plugins = {p.get_data_key(): p for p in plugins}
# This is used for non-multimodal models
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
self._limits_by_model = _MultiModalLimits()
[docs] def register_plugin(self, plugin: MultiModalPlugin) -> None:
"""
Register a multi-modal plugin so it can be recognized by vLLM.
See also:
:ref:`adding_multimodal_plugin`
"""
data_type_key = plugin.get_data_key()
if data_type_key in self._plugins:
logger.warning(
"A plugin is already registered for data type %s, "
"and will be overwritten by the new plugin %s.", data_type_key,
plugin)
self._plugins[data_type_key] = plugin
def _get_plugin(self, data_type_key: str):
plugin = self._plugins.get(data_type_key)
if plugin is not None:
return plugin
msg = f"Unknown multi-modal data type: {data_type_key}"
raise NotImplementedError(msg)
[docs] def register_max_multimodal_tokens(
self,
data_type_key: str,
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
):
"""
Register the maximum number of tokens, corresponding to a single
instance of multimodal data belonging to a specific modality, that are
passed to the language model for a model class.
"""
return self._get_plugin(data_type_key) \
.register_max_multimodal_tokens(max_mm_tokens)
[docs] def register_max_image_tokens(
self,
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
):
"""
Register the maximum number of image tokens, corresponding to a single
image, that are passed to the language model for a model class.
"""
return self.register_max_multimodal_tokens("image", max_mm_tokens)
[docs] def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
"""
Get the maximum number of multi-modal tokens
for profiling the memory usage of a model.
See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
limits_per_plugin = self._limits_by_model[model_config]
return sum((limits_per_plugin[key] *
plugin.get_max_multimodal_tokens(model_config))
for key, plugin in self._plugins.items())
[docs] def init_mm_limits_per_prompt(
self,
model_config: ModelConfig,
) -> None:
"""
Initialize the maximum number of multi-modal input instances for each
modality that are allowed per prompt for a model class.
"""
if model_config in self._limits_by_model:
logger.warning(
"`mm_limits` has already been set for model=%s, and will "
"be overwritten by the new values.", model_config.model)
multimodal_config = model_config.multimodal_config
if multimodal_config is None:
limits_per_plugin = self._disabled_limits_per_plugin
else:
config_limits_per_plugin = multimodal_config.limit_per_prompt
extra_keys = config_limits_per_plugin.keys() - self._plugins.keys()
if extra_keys:
logger.warning(
"Detected extra keys in `--limit-mm-per-prompt` which "
"are not registered as multi-modal plugins: %s. "
"They will be ignored.", extra_keys)
# NOTE: Currently the default is set to 1 for each plugin
# TODO: Automatically determine the limits based on budget
# once more models support multi-image inputs
limits_per_plugin = {
key: config_limits_per_plugin.get(key, 1)
for key in self._plugins
}
self._limits_by_model[model_config] = limits_per_plugin
[docs] def get_mm_limits_per_prompt(
self,
model_config: ModelConfig,
) -> Mapping[str, int]:
"""
Get the maximum number of multi-modal input instances for each modality
that are allowed per prompt for a model class.
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
return self._limits_by_model[model_config]