
vllm.model_executor

Modules:

Name              Description
custom_op
guided_decoding
layers
model_loader
models
parameter
pooling_metadata
sampling_metadata
utils             Utils for model executor.

__all__ module-attribute

__all__ = [
    "SamplingMetadata",
    "SamplingMetadataCache",
    "set_random_seed",
    "BasevLLMParameter",
    "PackedvLLMParameter",
]
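
Everything in __all__ is re-exported at the package level, so these names can be imported directly from vllm.model_executor:

from vllm.model_executor import (
    BasevLLMParameter,
    PackedvLLMParameter,
    SamplingMetadata,
    SamplingMetadataCache,
    set_random_seed,
)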

BasevLLMParameter

Bases: Parameter

Base parameter for vLLM linear layers. Extends torch.nn.Parameter by taking in a linear weight loader; the loaded weight is copied into the parameter when the provided weight loader is called. A usage sketch follows the source listing below.

Source code in vllm/model_executor/parameter.py
class BasevLLMParameter(Parameter):
    """
    Base parameter for vLLM linear layers. Extends the torch.nn.parameter
    by taking in a linear weight loader. Will copy the loaded weight
    into the parameter when the provided weight loader is called.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):

        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader: Callable):
        """
        Initialize the BasevLLMParameter

        :param data: torch tensor with the parameter data
        :param weight_loader: weight loader callable

        :returns: a torch.nn.parameter
        """

        # During weight loading, we often do something like:
        # narrowed_tensor = param.data.narrow(0, offset, len)
        # narrowed_tensor.copy_(real_weight)
        # expecting narrowed_tensor and param.data to share the same storage.
        # However, on TPUs, narrowed_tensor will lazily propagate to the base
        # tensor, which is param.data, leading to the redundant memory usage.
        # This sometimes causes OOM errors during model loading. To avoid this,
        # we sync the param tensor after its weight loader is called.
        from vllm.platforms import current_platform
        if current_platform.is_tpu():
            weight_loader = _make_synced_weight_loader(weight_loader)

        self._weight_loader = weight_loader

    @property
    def weight_loader(self):
        return self._weight_loader

    def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
        cond1 = self.data.ndim == 1 and self.data.numel() == 1
        cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
        return (cond1 and cond2)

    def _assert_and_load(self, loaded_weight: torch.Tensor):
        assert (self.data.shape == loaded_weight.shape
                or self._is_1d_and_scalar(loaded_weight))
        self.data.copy_(loaded_weight)

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)
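
As a minimal usage sketch of the weight-loader contract described above (the copy-through loader here is hypothetical and purely illustrative, not part of vLLM):

import torch

from vllm.model_executor import BasevLLMParameter

# Hypothetical loader for illustration: copy the checkpoint tensor into the
# parameter unchanged. Real vLLM layers install their own loaders.
def copy_weight_loader(param: torch.nn.Parameter,
                       loaded_weight: torch.Tensor) -> None:
    param.data.copy_(loaded_weight)

param = BasevLLMParameter(data=torch.empty(16, 32),
                          weight_loader=copy_weight_loader)

# The loader is reachable through the read-only `weight_loader` property.
param.weight_loader(param, torch.randn(16, 32))

# The load_* helpers assert that the shapes match and then copy the weight.
param.load_column_parallel_weight(torch.randn(16, 32))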

_weight_loader instance-attribute

_weight_loader = weight_loader

weight_loader property

weight_loader

__init__

__init__(data: Tensor, weight_loader: Callable)

Initialize the BasevLLMParameter

:param data: torch tensor with the parameter data
:param weight_loader: weight loader callable

:returns: a torch.nn.parameter

Source code in vllm/model_executor/parameter.py
def __init__(self, data: torch.Tensor, weight_loader: Callable):
    """
    Initialize the BasevLLMParameter

    :param data: torch tensor with the parameter data
    :param weight_loader: weight loader callable

    :returns: a torch.nn.parameter
    """

    # During weight loading, we often do something like:
    # narrowed_tensor = param.data.narrow(0, offset, len)
    # narrowed_tensor.copy_(real_weight)
    # expecting narrowed_tensor and param.data to share the same storage.
    # However, on TPUs, narrowed_tensor will lazily propagate to the base
    # tensor, which is param.data, leading to the redundant memory usage.
    # This sometimes causes OOM errors during model loading. To avoid this,
    # we sync the param tensor after its weight loader is called.
    from vllm.platforms import current_platform
    if current_platform.is_tpu():
        weight_loader = _make_synced_weight_loader(weight_loader)

    self._weight_loader = weight_loader

__new__

__new__(data: Tensor, **kwargs)
Source code in vllm/model_executor/parameter.py
def __new__(cls, data: torch.Tensor, **kwargs):

    return super().__new__(cls, data=data, requires_grad=False)

_assert_and_load

_assert_and_load(loaded_weight: Tensor)
Source code in vllm/model_executor/parameter.py
def _assert_and_load(self, loaded_weight: torch.Tensor):
    assert (self.data.shape == loaded_weight.shape
            or self._is_1d_and_scalar(loaded_weight))
    self.data.copy_(loaded_weight)

_is_1d_and_scalar

_is_1d_and_scalar(loaded_weight: Tensor)
Source code in vllm/model_executor/parameter.py
def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
    cond1 = self.data.ndim == 1 and self.data.numel() == 1
    cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
    return (cond1 and cond2)

load_column_parallel_weight

load_column_parallel_weight(loaded_weight: Tensor)
Source code in vllm/model_executor/parameter.py
def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
    self._assert_and_load(loaded_weight)

load_merged_column_weight

load_merged_column_weight(loaded_weight: Tensor, **kwargs)
Source code in vllm/model_executor/parameter.py
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
    self._assert_and_load(loaded_weight)

load_qkv_weight

load_qkv_weight(loaded_weight: Tensor, **kwargs)
Source code in vllm/model_executor/parameter.py
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
    self._assert_and_load(loaded_weight)

load_row_parallel_weight

load_row_parallel_weight(loaded_weight: Tensor)
Source code in vllm/model_executor/parameter.py
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
    self._assert_and_load(loaded_weight)

PackedvLLMParameter

Bases: ModelWeightParameter

Parameter for model weights that are packed on disk. Example: GPTQ Marlin weights are int4 or int8, packed into int32. Extends ModelWeightParameter to take in the packed factor, the packed dimension, and, optionally, a Marlin tile size for Marlin kernels. Adjusts shard_size and shard_offset during fused-linear-layer weight loading to account for packing and, optionally, the Marlin tile size. A worked sketch follows the source listing below.

Source code in vllm/model_executor/parameter.py
class PackedvLLMParameter(ModelWeightParameter):
    """
    Parameter for model weights which are packed on disk.
    Example: GPTQ Marlin weights are int4 or int8, packed into int32.
    Extends the ModelWeightParameter to take in the
    packed factor, the packed dimension, and optionally, marlin
    tile size for marlin kernels. Adjusts the shard_size and 
    shard_offset for fused linear layers model weight loading
    by accounting for packing and optionally, marlin tile size.
    """

    def __init__(self,
                 packed_factor: Union[int, Fraction],
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 bitblas_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        self._bitblas_tile_size = bitblas_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    @property
    def bitblas_tile_size(self):
        return self._bitblas_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size,
            bitblas_tile_size=self.bitblas_tile_size)
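
A minimal sketch of the shard-index bookkeeping, assuming a GPTQ-style layout (int4 packed into int32, so packed_factor = 8) and assuming ModelWeightParameter accepts input_dim and output_dim keyword arguments (a detail not shown on this page); the shapes, dimensions, and no-op loader are illustrative only:

import torch

from vllm.model_executor import PackedvLLMParameter

# Illustrative loader only; real layers supply their own.
def noop_loader(param, loaded_weight):
    param.data.copy_(loaded_weight)

# int4 values packed into int32: 32 / 4 = 8 logical values per stored element,
# so the packed dimension is 8x smaller than the logical weight.
packed = PackedvLLMParameter(
    data=torch.zeros(4096, 4096 // 8, dtype=torch.int32),
    weight_loader=noop_loader,
    input_dim=1,   # assumed ModelWeightParameter kwargs
    output_dim=0,
    packed_factor=8,
    packed_dim=1,
)

# Shard bookkeeping for fused layers is adjusted to packed units.
shard_size, shard_offset = packed.adjust_shard_indexes_for_packing(
    shard_size=4096, shard_offset=8192)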

_bitblas_tile_size instance-attribute

_bitblas_tile_size = bitblas_tile_size

_marlin_tile_size instance-attribute

_marlin_tile_size = marlin_tile_size

_packed_dim instance-attribute

_packed_dim = packed_dim

_packed_factor instance-attribute

_packed_factor = packed_factor

bitblas_tile_size property

bitblas_tile_size

marlin_tile_size property

marlin_tile_size

packed_dim property

packed_dim

packed_factor property

packed_factor

__init__

__init__(
    packed_factor: Union[int, Fraction],
    packed_dim: int,
    marlin_tile_size: Optional[int] = None,
    bitblas_tile_size: Optional[int] = None,
    **kwargs,
)
Source code in vllm/model_executor/parameter.py
def __init__(self,
             packed_factor: Union[int, Fraction],
             packed_dim: int,
             marlin_tile_size: Optional[int] = None,
             bitblas_tile_size: Optional[int] = None,
             **kwargs):
    self._packed_factor = packed_factor
    self._packed_dim = packed_dim
    self._marlin_tile_size = marlin_tile_size
    self._bitblas_tile_size = bitblas_tile_size
    super().__init__(**kwargs)

adjust_shard_indexes_for_packing

adjust_shard_indexes_for_packing(shard_size, shard_offset)
Source code in vllm/model_executor/parameter.py
def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
    return _adjust_shard_indexes_for_packing(
        shard_size=shard_size,
        shard_offset=shard_offset,
        packed_factor=self.packed_factor,
        marlin_tile_size=self.marlin_tile_size,
        bitblas_tile_size=self.bitblas_tile_size)

SamplingMetadata

Metadata for input sequences. Used in sampler.

The usage is as follows:

hidden_states = execute_model(...)
logits = hidden_states[sampling_metadata.selected_token_indices]
sample(logits)

def sample(logits):
    # Use categorized_sample_indices for sampling....

Parameters:

seq_groups (list[SequenceGroupToSample], required):
    List of batched sequence groups.

selected_token_indices (Tensor, required):
    Shape (num_query_tokens_to_logprob). Indices used to find logits from the initial model output hidden states.

categorized_sample_indices (dict[SamplingType, Tensor], required):
    SamplingType -> token indices to sample. Each set of token indices is a 2D tensor of shape (num_indices, 2), where the first item is the sample index within the returned logits (before pruning padding) and the second item is the sample index after pruning using selected_token_indices. For example, if the returned logits are [1, 2, 3] and we select [1, 2] for sampling, the pruned logits will be [2, 3]. In this case, the first tuple is [1, 2] (sample indices within the original logits) and the second tuple is [0, 1] (sample indices within the pruned logits). A toy illustration follows the source listing below.

num_prompts (int, required):
    Number of prompt sequence groups in seq_groups.

skip_sampler_cpu_output (bool, default False):
    Indicates whether to skip the GPU=>CPU serialization of token outputs.

reuse_sampling_tensors (bool, default False):
    Indicates whether to reuse sampling tensors that are part of the sampler forward pass. Currently, it is mainly used for multi-step decode.
Source code in vllm/model_executor/sampling_metadata.py
class SamplingMetadata:
    """Metadata for input sequences. Used in sampler.

    The usage is as follow;
    ```
    hidden_states = execute_model(...)
    logits = hidden_states[sampling_metadata.selected_token_indices]
    sample(logits)

    def sample(logits):
        # Use categorized_sample_indices for sampling....
    ```

    Args:
        seq_groups: List of batched sequence groups.
        selected_token_indices: (num_query_tokens_to_logprob). Indices to find
            logits from the initial model output hidden states.
        categorized_sample_indices: SamplingType -> token indices to sample.
            Each token indices is 2D tensor of (num_indices, num_indices) where
            the first item means the sample index within the returned logit
            (before pruning padding), and the second item means the sample
            index after pruning using selected_token_indices.
            For example, if the returned logit is [1, 2, 3], and we select
            [1, 2] for sampling, the pruned logit will be [2, 3]. In this case,
            The first tuple is [1, 2] (sampled index within original logit),
            and the second tuple is [0, 1] (sampled index within pruned logit).
        num_prompts: Number of prompt sequence groups in seq_groups.
        skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
            serialization of token outputs.
        reuse_sampling_tensors: Indicates if we want to reuse sampling
            tensors that are part of the sampler forward pass. Currently,
            it is mainly used for multi-step decode.

    """

    def __init__(
        self,
        seq_groups: list[SequenceGroupToSample],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: dict[SamplingType, torch.Tensor],
        num_prompts: int,
        skip_sampler_cpu_output: bool = False,
        reuse_sampling_tensors: bool = False,
    ) -> None:
        self.seq_groups = seq_groups
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.num_prompts = num_prompts
        self.skip_sampler_cpu_output = skip_sampler_cpu_output
        self.reuse_sampling_tensors = reuse_sampling_tensors

    @staticmethod
    def prepare(
        seq_group_metadata_list: list[SequenceGroupMetadata],
        seq_lens: list[int],
        query_lens: list[int],
        device: str,
        pin_memory: bool,
        generators: Optional[dict[str, torch.Generator]] = None,
        cache: Optional[SamplingMetadataCache] = None,
    ) -> "SamplingMetadata":
        (
            seq_groups,
            selected_token_indices,
            categorized_sample_indices,
            num_prompts,
        ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
                                device, generators, cache)
        selected_token_indices = async_tensor_h2d(
            selected_token_indices,
            dtype=torch.long,
            target_device=device,
            pin_memory=pin_memory,
        )
        categorized_sample_indices = {
            t:
            async_tensor_h2d(
                seq_ids,
                dtype=torch.int,
                target_device=device,
                pin_memory=pin_memory,
            )
            for t, seq_ids in categorized_sample_indices.items()
        }

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            num_prompts=num_prompts,
        )
        return sampling_metadata

    def __repr__(self) -> str:
        return (
            "SamplingMetadata("
            f"seq_groups={self.seq_groups}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices})")

categorized_sample_indices instance-attribute

categorized_sample_indices = categorized_sample_indices

num_prompts instance-attribute

num_prompts = num_prompts

reuse_sampling_tensors instance-attribute

reuse_sampling_tensors = reuse_sampling_tensors

selected_token_indices instance-attribute

selected_token_indices = selected_token_indices

seq_groups instance-attribute

seq_groups = seq_groups

skip_sampler_cpu_output instance-attribute

skip_sampler_cpu_output = skip_sampler_cpu_output

__init__

__init__(
    seq_groups: list[SequenceGroupToSample],
    selected_token_indices: Tensor,
    categorized_sample_indices: dict[SamplingType, Tensor],
    num_prompts: int,
    skip_sampler_cpu_output: bool = False,
    reuse_sampling_tensors: bool = False,
) -> None
Source code in vllm/model_executor/sampling_metadata.py
def __init__(
    self,
    seq_groups: list[SequenceGroupToSample],
    selected_token_indices: torch.Tensor,
    categorized_sample_indices: dict[SamplingType, torch.Tensor],
    num_prompts: int,
    skip_sampler_cpu_output: bool = False,
    reuse_sampling_tensors: bool = False,
) -> None:
    self.seq_groups = seq_groups
    self.selected_token_indices = selected_token_indices
    self.categorized_sample_indices = categorized_sample_indices
    self.num_prompts = num_prompts
    self.skip_sampler_cpu_output = skip_sampler_cpu_output
    self.reuse_sampling_tensors = reuse_sampling_tensors

__repr__

__repr__() -> str
Source code in vllm/model_executor/sampling_metadata.py
def __repr__(self) -> str:
    return (
        "SamplingMetadata("
        f"seq_groups={self.seq_groups}, "
        f"selected_token_indices={self.selected_token_indices}, "
        f"categorized_sample_indices={self.categorized_sample_indices})")

prepare staticmethod

prepare(
    seq_group_metadata_list: list[SequenceGroupMetadata],
    seq_lens: list[int],
    query_lens: list[int],
    device: str,
    pin_memory: bool,
    generators: Optional[dict[str, Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
) -> SamplingMetadata
Source code in vllm/model_executor/sampling_metadata.py
@staticmethod
def prepare(
    seq_group_metadata_list: list[SequenceGroupMetadata],
    seq_lens: list[int],
    query_lens: list[int],
    device: str,
    pin_memory: bool,
    generators: Optional[dict[str, torch.Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
) -> "SamplingMetadata":
    (
        seq_groups,
        selected_token_indices,
        categorized_sample_indices,
        num_prompts,
    ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
                            device, generators, cache)
    selected_token_indices = async_tensor_h2d(
        selected_token_indices,
        dtype=torch.long,
        target_device=device,
        pin_memory=pin_memory,
    )
    categorized_sample_indices = {
        t:
        async_tensor_h2d(
            seq_ids,
            dtype=torch.int,
            target_device=device,
            pin_memory=pin_memory,
        )
        for t, seq_ids in categorized_sample_indices.items()
    }

    sampling_metadata = SamplingMetadata(
        seq_groups=seq_groups,
        selected_token_indices=selected_token_indices,
        categorized_sample_indices=categorized_sample_indices,
        num_prompts=num_prompts,
    )
    return sampling_metadata

SamplingMetadataCache

Used to cache SamplingMetadata objects between scheduler iterations

Source code in vllm/model_executor/sampling_metadata.py
class SamplingMetadataCache:
    """Used to cache SamplingMetadata objects between scheduler iterations"""

    def __init__(self):
        self._seq_group_to_sample_cache: dict[int, PyObjectCache] = {}

    def get_cached_seq_group_to_sample(self, num_seqs):
        if num_seqs not in self._seq_group_to_sample_cache:
            self._seq_group_to_sample_cache[num_seqs] = PyObjectCache(
                gen_seq_group_to_sample_builder(num_seqs))

        obj = self._seq_group_to_sample_cache[num_seqs].get_object()
        return obj

    def reset(self):
        for cache in self._seq_group_to_sample_cache.values():
            cache.reset()
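
A small usage sketch of the cache across scheduler iterations:

from vllm.model_executor import SamplingMetadataCache

cache = SamplingMetadataCache()

# Fetch a pooled SequenceGroupToSample object for a group of two sequences;
# the underlying PyObjectCache is created lazily per group size.
seq_group = cache.get_cached_seq_group_to_sample(num_seqs=2)

# Reset the pools so the next scheduler iteration can reuse the cached objects.
cache.reset()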

_seq_group_to_sample_cache instance-attribute

_seq_group_to_sample_cache: dict[int, PyObjectCache] = {}

__init__

__init__()
Source code in vllm/model_executor/sampling_metadata.py
def __init__(self):
    self._seq_group_to_sample_cache: dict[int, PyObjectCache] = {}

get_cached_seq_group_to_sample

get_cached_seq_group_to_sample(num_seqs)
Source code in vllm/model_executor/sampling_metadata.py
def get_cached_seq_group_to_sample(self, num_seqs):
    if num_seqs not in self._seq_group_to_sample_cache:
        self._seq_group_to_sample_cache[num_seqs] = PyObjectCache(
            gen_seq_group_to_sample_builder(num_seqs))

    obj = self._seq_group_to_sample_cache[num_seqs].get_object()
    return obj

reset

reset()
Source code in vllm/model_executor/sampling_metadata.py
def reset(self):
    for cache in self._seq_group_to_sample_cache.values():
        cache.reset()

set_random_seed

set_random_seed(seed: int) -> None
Source code in vllm/model_executor/utils.py
def set_random_seed(seed: int) -> None:
    from vllm.platforms import current_platform
    current_platform.seed_everything(seed)
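
For example, seeding the platform RNGs before loading a model:

from vllm.model_executor import set_random_seed

# Delegates to current_platform.seed_everything(seed) on the active platform.
set_random_seed(42)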