# Source code for vllm.multimodal.parse

from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar,
                    Union)

import numpy as np
import torch
from PIL.Image import Image
from typing_extensions import TypeAlias, TypeGuard, assert_never

from vllm.utils import is_list_of

from .audio import resample_audio
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
                     ImageItem, ModalityData, MultiModalDataDict, VideoItem)

# Generic parameters of ``ModalityDataItems[_T, _I]``:
# ``_T`` is the type of the stored ``data`` and ``_I`` is the type of a single
# item returned by ``get``. ``ProcessorBatchItems`` reuses ``_T`` as the
# element type of its underlying sequence.
_T = TypeVar("_T")
_I = TypeVar("_I")


class ModalityDataItems(ABC, Generic[_T, _I]):
    """
    Represents data items for a modality in :class:`MultiModalDataItems`.

    ``_T`` is the type of the underlying :attr:`data`, and ``_I`` is the type
    of a single item as returned by :meth:`get`.
    """

    def __init__(self, data: _T, modality: str) -> None:
        super().__init__()

        self.modality = modality
        self.data = data

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        num_items = len(self)
        return f"{cls_name}(modality={self.modality!r}, len={num_items})"

    def __len__(self) -> int:
        # ``len(items)`` is simply an alias of :meth:`get_count`.
        return self.get_count()

    def __getitem__(self, index: int) -> _I:
        # ``items[index]`` is simply an alias of :meth:`get`.
        return self.get(index)

    if TYPE_CHECKING:
        # Only declared for type checkers; nothing is defined here at runtime,
        # where Python falls back to ``__getitem__``-based iteration.
        def __iter__(self) -> Iterator[_I]:
            ...

    @abstractmethod
    def get_count(self) -> int:
        """Get the number of data items."""
        raise NotImplementedError

    @abstractmethod
    def get(self, index: int) -> _I:
        """Get a data item by its index."""
        raise NotImplementedError

    def get_all(self) -> list[_I]:
        """Get all data items, in index order."""
        return list(map(self.get, range(self.get_count())))

    @abstractmethod
    def get_processor_data(self) -> Mapping[str, object]:
        """Get the data to pass to the HF processor."""
        raise NotImplementedError

    @abstractmethod
    def get_passthrough_data(self) -> Mapping[str, object]:
        """Get the data to pass directly to the model."""
        raise NotImplementedError
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
    """Base class for data items that are arranged in a list."""

    def get_count(self) -> int:
        """Return the length of the underlying sequence."""
        items = self.data
        return len(items)

    def get(self, index: int) -> _T:
        """Return the element of the underlying sequence at ``index``."""
        items = self.data
        return items[index]

    def get_processor_data(self) -> Mapping[str, object]:
        """
        Return the whole sequence under a single key: the modality name
        with an ``"s"`` suffix.
        """
        key = f"{self.modality}s"
        return {key: self.data}

    def get_passthrough_data(self) -> Mapping[str, object]:
        """Return an empty mapping; these items pass nothing through."""
        return dict()
class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]],
                                       torch.Tensor]):
    """
    Base class for data items that are expressed as a batched embedding
    tensor, or a list of embedding tensors (one per item).
    """

    def get_count(self) -> int:
        """Return ``len()`` of the underlying tensor or list."""
        embeds = self.data
        return len(embeds)

    def get(self, index: int) -> torch.Tensor:
        """Return the embedding obtained by indexing the data at ``index``."""
        embeds = self.data
        return embeds[index]

    def get_processor_data(self) -> Mapping[str, object]:
        """Return an empty mapping; embeddings send nothing to the processor."""
        return dict()

    def get_passthrough_data(self) -> Mapping[str, object]:
        """
        Return the embeddings under a single key: the modality name with an
        ``"_embeds"`` suffix.
        """
        key = f"{self.modality}_embeds"
        return {key: self.data}

    def get_feature_size(self, item_idx: int) -> int:
        """Return ``len()`` of the embedding at ``item_idx``."""
        embedding = self.get(item_idx)
        return len(embedding)
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
    """A sequence of audio items registered under the ``"audio"`` modality."""

    def __init__(self, data: Sequence[HfAudioItem]) -> None:
        super().__init__(data=data, modality="audio")
class AudioEmbeddingItems(EmbeddingItems):
    """Embedding items registered under the ``"audio"`` modality."""

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data=data, modality="audio")
class ImageSize(NamedTuple):
    """A ``(width, height)`` pair; note that the width comes first."""

    width: int
    height: int
class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
    """A sequence of image items registered under the ``"image"`` modality."""

    def __init__(self, data: Sequence[HfImageItem]) -> None:
        super().__init__(data=data, modality="image")

    def get_image_size(self, item_idx: int) -> ImageSize:
        """
        Get the size of the image at ``item_idx``.

        PIL images report their own ``size``. Arrays and tensors must have
        exactly three dimensions, which are read as ``(_, height, width)``.
        """
        item = self.get(item_idx)

        if isinstance(item, Image):
            return ImageSize(*item.size)

        if isinstance(item, (np.ndarray, torch.Tensor)):
            # NOTE(review): presumably a channels-first layout; the leading
            # dimension is ignored either way.
            _lead, height, width = item.shape
            return ImageSize(width=width, height=height)

        # Raises at runtime if an unsupported item type slipped through.
        assert_never(item)
class ImageEmbeddingItems(EmbeddingItems):
    """Embedding items registered under the ``"image"`` modality."""

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data=data, modality="image")
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
    """A sequence of video items registered under the ``"video"`` modality."""

    def __init__(self, data: Sequence[HfVideoItem]) -> None:
        super().__init__(data=data, modality="video")

    def get_num_frames(self, item_idx: int) -> int:
        """Return ``len()`` of the video at ``item_idx``."""
        video = self.get(item_idx)
        return len(video)

    def get_frame_size(self, item_idx: int) -> ImageSize:
        """
        Get the frame size of the video at ``item_idx``.

        Only the first frame is inspected, so the video must not be empty.
        PIL frames report their own ``size``. Array and tensor frames must
        have exactly three dimensions, read as ``(_, height, width)``.
        """
        video = self.get(item_idx)
        first_frame = video[0]  # Assume that the video isn't empty

        if isinstance(first_frame, Image):
            return ImageSize(*first_frame.size)

        if isinstance(first_frame, (np.ndarray, torch.Tensor)):
            # NOTE(review): presumably a channels-first layout; the leading
            # dimension is ignored either way.
            _lead, height, width = first_frame.shape
            return ImageSize(width=width, height=height)

        # Raises at runtime if an unsupported frame type slipped through.
        assert_never(first_frame)
class VideoEmbeddingItems(EmbeddingItems):
    """Embedding items registered under the ``"video"`` modality."""

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data=data, modality="video")
# Any concrete ``ModalityDataItems`` subclass; lets
# ``MultiModalDataItems.get_items`` return the same type it was asked for.
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
    """
    As :data:`~vllm.multimodal.inputs.MultiModalDataDict`, but normalized
    such that each entry corresponds to a list.
    """

    def get_count(self, modality: str, *, strict: bool = True) -> int:
        """
        Get the number of data items belonging to a modality.

        If `strict=False`, return `0` instead of raising :exc:`KeyError`
        even if the modality is not found.
        """
        if modality in self:
            return self[modality].get_count()

        if not strict:
            return 0

        available_modalities = set(self)
        raise KeyError(f"Modality {modality!r} not found. "
                       f"Available modalities: {available_modalities}")

    def get_all_counts(self) -> Mapping[str, int]:
        """Get the number of items belonging to each modality."""
        counts = dict[str, int]()
        for name, modality_items in self.items():
            counts[name] = modality_items.get_count()

        return counts

    def get_items(
        self,
        modality: str,
        typ: Union[type[_D], tuple[type[_D], ...]],
    ) -> _D:
        """
        Get the data items belonging to a modality,
        requiring that they belong to a certain type.

        Raises:
            KeyError: If the modality is not present.
            TypeError: If the stored items are not an instance of ``typ``.
        """
        if modality not in self:
            available_modalities = set(self)
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {available_modalities}")

        items = self[modality]
        if isinstance(items, typ):
            return items  # type: ignore[return-value]

        raise TypeError(f"Invalid type of data items for {modality=}. "
                        f"Expected type: {typ}, but "
                        f"found type: {type(items)}")
# Signature of a per-modality parser: takes the raw data of one modality and
# returns the corresponding ``ModalityDataItems``.
ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]], ModalityDataItems[Any, Any]]
class MultiModalDataParser:
    """
    Parses :data:`~vllm.multimodal.inputs.MultiModalDataDict` into
    :class:`MultiModalDataItems`.

    Args:
        target_sr (float, optional): Enables automatic resampling of audio
            items to the model's expected sampling rate.
    """

    def __init__(self, *, target_sr: Optional[float] = None) -> None:
        super().__init__()

        self.target_sr = target_sr

    def _is_embeddings(
        self, data: object
    ) -> TypeGuard[Union[torch.Tensor, list[torch.Tensor]]]:
        """
        Tell whether ``data`` should be treated as embeddings.

        This holds for a 3-D tensor, and for a list of tensors that is
        either empty or starts with a 2-D tensor.
        """
        if isinstance(data, torch.Tensor):
            return data.ndim == 3

        if is_list_of(data, torch.Tensor):
            if len(data) == 0:
                return True
            # Only the first element is inspected.
            return data[0].ndim == 2

        return False

    def _get_audio_with_sr(
        self,
        audio: AudioItem,
    ) -> tuple[np.ndarray, Optional[float]]:
        """
        Split an audio item into its samples and its sampling rate.

        A tuple is returned unchanged. Every other supported input is
        converted to a NumPy array and paired with ``None``.
        """
        if isinstance(audio, tuple):
            return audio

        if isinstance(audio, list):
            samples = np.array(audio)
        elif isinstance(audio, np.ndarray):
            samples = audio
        elif isinstance(audio, torch.Tensor):
            samples = audio.numpy()
        else:
            assert_never(audio)

        return samples, None

    def _parse_audio_data(
        self,
        data: ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any]:
        """
        Parse the raw ``"audio"`` entry.

        Embeddings are wrapped as they are. Otherwise the input is normalized
        to a list of audio items, and every item that carries a sampling rate
        is resampled to ``target_sr``.

        Raises:
            RuntimeError: If an item carries a sampling rate but no
                ``target_sr`` was provided.
        """
        if self._is_embeddings(data):
            return AudioEmbeddingItems(data)

        if is_list_of(data, float) or isinstance(data, tuple):
            # A flat list of floats or a tuple is one audio item.
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A 1-D array is one audio item; otherwise split along dim 0.
            data_items = [data] if data.ndim == 1 else list(data)
        else:
            data_items = data

        new_audios: list[np.ndarray] = []
        for data_item in data_items:
            audio, orig_sr = self._get_audio_with_sr(data_item)

            if orig_sr is not None:
                if self.target_sr is None:
                    raise RuntimeError(
                        "Audio resampling is not supported when "
                        "`target_sr` is not provided")

                audio = resample_audio(audio,
                                       orig_sr=orig_sr,
                                       target_sr=self.target_sr)

            new_audios.append(audio)

        return AudioProcessorItems(new_audios)

    def _parse_image_data(
        self,
        data: ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any]:
        """
        Parse the raw ``"image"`` entry.

        Embeddings are wrapped as they are. Otherwise the input is normalized
        to a list of images.
        """
        if self._is_embeddings(data):
            return ImageEmbeddingItems(data)

        if isinstance(data, Image):
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A 3-D array is one image; otherwise split along dim 0.
            data_items = [data] if data.ndim == 3 else list(data)
        else:
            data_items = data

        return ImageProcessorItems(data_items)

    def _parse_video_data(
        self,
        data: ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any]:
        """
        Parse the raw ``"video"`` entry.

        Embeddings are wrapped as they are. Otherwise the input is normalized
        to a list of videos.
        """
        if self._is_embeddings(data):
            return VideoEmbeddingItems(data)

        if is_list_of(data, Image):
            # A list of PIL images is the frames of one video.
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            # A 4-D array is one video; otherwise split along dim 0.
            data_items = [data] if data.ndim == 4 else list(data)
        else:
            data_items = data

        return VideoProcessorItems(data_items)

    def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
        """Map each supported modality name to the method that parses it."""
        return dict(
            audio=self._parse_audio_data,
            image=self._parse_image_data,
            video=self._parse_video_data,
        )

    def parse_mm_data(self,
                      mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """
        Parse every entry of ``mm_data`` with the subparser of its modality.

        Raises:
            ValueError: If ``mm_data`` has a modality without a subparser.
        """
        subparsers = self._get_subparsers()

        parsed = MultiModalDataItems()
        for modality, raw_data in mm_data.items():
            if modality not in subparsers:
                raise ValueError(f"Unsupported modality: {modality}")

            parse = subparsers[modality]
            parsed[modality] = parse(raw_data)

        return parsed