Skip to content

vllm.model_executor.models.interfaces_base

T module-attribute

# Invariant type variable used by the pooling and text-generation
# protocols in this module; falls back to `Tensor` when unspecified.
T = TypeVar('T', default=Tensor)

T_co module-attribute

# Covariant type variable for the return type of `VllmModel.forward`;
# falls back to `Tensor` when unspecified.
T_co = TypeVar('T_co', default=Tensor, covariant=True)

logger module-attribute

# Module-level logger; `_check_vllm_model_forward` uses it to warn about
# `nn.Module` classes whose `forward` lacks the vLLM keywords.
logger = init_logger(__name__)

VllmModel

Bases: Protocol[T_co]

The interface required for all models in vLLM.

Source code in vllm/model_executor/models/interfaces_base.py
@runtime_checkable
class VllmModel(Protocol[T_co]):
    """The interface required for all models in vLLM.

    This is a runtime-checkable structural protocol, generic over the
    covariant type ``T_co`` returned by ``forward``. Note that a runtime
    ``isinstance`` check against a protocol only verifies that the
    members exist; it does not validate their signatures.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        prefix: str = "",
    ) -> None:
        """Construct the model.

        Args:
            vllm_config: A ``VllmConfig`` instance.
            prefix: String that is empty by default.
                NOTE(review): presumably a name prefix for the model's
                submodules/weights -- confirm against implementers.
        """
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> T_co:
        """Run the model's forward pass.

        Args:
            input_ids: Tensor of input IDs.
            positions: Tensor of positions.
                NOTE(review): shapes and dtypes of both tensors are not
                specified by this protocol -- confirm with callers.

        Returns:
            A value of type ``T_co``.
        """
        ...

__init__

__init__(vllm_config: VllmConfig, prefix: str = '') -> None
Source code in vllm/model_executor/models/interfaces_base.py
def __init__(
    self,
    vllm_config: "VllmConfig",
    prefix: str = "",
) -> None:
    """Construct the model.

    Args:
        vllm_config: A ``VllmConfig`` instance.
        prefix: String that is empty by default.
            NOTE(review): presumably a name prefix for the model's
            submodules/weights -- confirm against implementers.
    """
    ...

forward

forward(input_ids: Tensor, positions: Tensor) -> T_co
Source code in vllm/model_executor/models/interfaces_base.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
) -> T_co:
    """Run the model's forward pass.

    Args:
        input_ids: Tensor of input IDs.
        positions: Tensor of positions.
            NOTE(review): shapes and dtypes of both tensors are not
            specified by this protocol -- confirm with callers.

    Returns:
        A value of type ``T_co``.
    """
    ...

VllmModelForPooling

Bases: VllmModel[T], Protocol[T]

The interface required for all pooling models in vLLM.

Source code in vllm/model_executor/models/interfaces_base.py
@runtime_checkable
class VllmModelForPooling(VllmModel[T], Protocol[T]):
    """The interface required for all pooling models in vLLM.

    Extends ``VllmModel`` with a ``pooler`` method. ``T`` is the type of
    the hidden states passed to ``pooler``.
    """

    def pooler(
        self,
        hidden_states: T,
        pooling_metadata: "PoolingMetadata",
    ) -> "PoolerOutput":
        """Pool the hidden states. Only called on TP rank 0.

        Args:
            hidden_states: Hidden states of type ``T``.
                NOTE(review): presumably the output of ``forward`` --
                confirm with callers.
            pooling_metadata: A ``PoolingMetadata`` instance.

        Returns:
            A ``PoolerOutput``.
        """
        ...

pooler

pooler(
    hidden_states: T, pooling_metadata: PoolingMetadata
) -> PoolerOutput

Only called on TP rank 0.

Source code in vllm/model_executor/models/interfaces_base.py
def pooler(
    self,
    hidden_states: T,
    pooling_metadata: "PoolingMetadata",
) -> "PoolerOutput":
    """Pool the hidden states. Only called on TP rank 0.

    Args:
        hidden_states: Hidden states of type ``T``.
            NOTE(review): presumably the output of ``forward`` --
            confirm with callers.
        pooling_metadata: A ``PoolingMetadata`` instance.

    Returns:
        A ``PoolerOutput``.
    """
    ...

VllmModelForTextGeneration

Bases: VllmModel[T], Protocol[T]

The interface required for all generative models in vLLM.

Source code in vllm/model_executor/models/interfaces_base.py
@runtime_checkable
class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
    """The interface required for all generative models in vLLM.

    Extends ``VllmModel`` with a ``compute_logits`` method. ``T`` is the
    type of both the hidden states it receives and the logits it returns.
    """

    def compute_logits(
        self,
        hidden_states: T,
        sampling_metadata: "SamplingMetadata",
    ) -> Optional[T]:
        """Compute logits from hidden states.

        Return `None` if TP rank > 0.

        Args:
            hidden_states: Hidden states of type ``T``.
                NOTE(review): presumably the output of ``forward`` --
                confirm with callers.
            sampling_metadata: A ``SamplingMetadata`` instance.

        Returns:
            The logits as a ``T``, or ``None`` (see above).
        """
        ...

compute_logits

compute_logits(
    hidden_states: T, sampling_metadata: SamplingMetadata
) -> Optional[T]

Return None if TP rank > 0.

Source code in vllm/model_executor/models/interfaces_base.py
def compute_logits(
    self,
    hidden_states: T,
    sampling_metadata: "SamplingMetadata",
) -> Optional[T]:
    """Compute logits from hidden states.

    Return `None` if TP rank > 0.

    Args:
        hidden_states: Hidden states of type ``T``.
            NOTE(review): presumably the output of ``forward`` --
            confirm with callers.
        sampling_metadata: A ``SamplingMetadata`` instance.

    Returns:
        The logits as a ``T``, or ``None`` (see above).
    """
    ...

_check_vllm_model_forward

_check_vllm_model_forward(
    model: Union[type[object], object],
) -> bool
Source code in vllm/model_executor/models/interfaces_base.py
def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    vllm_kws = ("input_ids", "positions")
    missing_kws = tuple(kw for kw in vllm_kws
                        if not supports_kw(model_forward, kw))

    if missing_kws and (isinstance(model, type)
                        and issubclass(model, nn.Module)):
        logger.warning(
            "The model (%s) is missing "
            "vLLM-specific keywords from its `forward` method: %s",
            model,
            missing_kws,
        )

    return len(missing_kws) == 0

_check_vllm_model_init

_check_vllm_model_init(
    model: Union[type[object], object],
) -> bool
Source code in vllm/model_executor/models/interfaces_base.py
def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
    """Check that ``model.__init__`` accepts the ``vllm_config`` keyword.

    Args:
        model: A model class or a model instance.

    Returns:
        The result of ``supports_kw`` for ``vllm_config`` on the model's
        ``__init__``.
    """
    return supports_kw(model.__init__, "vllm_config")

is_pooling_model

is_pooling_model(
    model: type[object],
) -> TypeIs[type[VllmModelForPooling]]
is_pooling_model(
    model: object,
) -> TypeIs[VllmModelForPooling]
Source code in vllm/model_executor/models/interfaces_base.py
def is_pooling_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
    """Tell whether ``model`` is a vLLM pooling model.

    Args:
        model: A model class or a model instance.

    Returns:
        ``False`` unless ``is_vllm_model(model)`` holds; otherwise the
        result of the runtime ``isinstance`` check of ``model`` against
        the ``VllmModelForPooling`` protocol.
    """
    if is_vllm_model(model):
        # The same protocol check is applied to classes and instances.
        return isinstance(model, VllmModelForPooling)

    return False

is_text_generation_model

is_text_generation_model(
    model: type[object],
) -> TypeIs[type[VllmModelForTextGeneration]]
is_text_generation_model(
    model: object,
) -> TypeIs[VllmModelForTextGeneration]
Source code in vllm/model_executor/models/interfaces_base.py
def is_text_generation_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForTextGeneration]],
           TypeIs[VllmModelForTextGeneration]]:
    """Tell whether ``model`` is a vLLM text-generation model.

    Args:
        model: A model class or a model instance.

    Returns:
        ``False`` unless ``is_vllm_model(model)`` holds; otherwise the
        result of the runtime ``isinstance`` check of ``model`` against
        the ``VllmModelForTextGeneration`` protocol.
    """
    if is_vllm_model(model):
        # The same protocol check is applied to classes and instances.
        return isinstance(model, VllmModelForTextGeneration)

    return False

is_vllm_model

is_vllm_model(
    model: type[object],
) -> TypeIs[type[VllmModel]]
is_vllm_model(model: object) -> TypeIs[VllmModel]
is_vllm_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]
Source code in vllm/model_executor/models/interfaces_base.py
def is_vllm_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
    """Tell whether ``model`` satisfies the basic vLLM model interface.

    Args:
        model: A model class or a model instance.

    Returns:
        ``True`` only when both the ``__init__`` check and the ``forward``
        check pass. The ``forward`` check is skipped when the ``__init__``
        check fails.
    """
    if not _check_vllm_model_init(model):
        return False

    return _check_vllm_model_forward(model)