Skip to content

vllm.utils.mem_utils

Classes:

Functions:

MemoryProfilingResult dataclass

Memory profiling result. All numbers are in bytes.

Source code in vllm/utils/mem_utils.py
@dataclass
class MemoryProfilingResult:
    """Outcome of a memory-profiling run. Every size field is in bytes.

    ``before_create`` is the baseline snapshot taken before the vLLM instance
    was created. Two additional, initially unmeasured snapshots
    (``before_profile`` and ``after_profile``) are attached in
    ``__post_init__`` on the same device as ``before_create``.
    """

    non_kv_cache_memory: int = 0
    torch_peak_increase: int = 0
    non_torch_increase: int = 0
    weights_memory: int = 0
    before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
    profile_time: float = 0.0

    def __post_init__(self) -> None:
        # Both snapshots start empty (auto_measure=False); callers are expected
        # to fill them later via ``measure()``.
        target_device = self.before_create.device_
        self.before_profile = MemorySnapshot(
            device=target_device, auto_measure=False
        )
        self.after_profile = MemorySnapshot(
            device=target_device, auto_measure=False
        )

    def __repr__(self) -> str:
        pieces = [
            f"Memory profiling takes {self.profile_time:.2f} seconds. ",
            "Total non KV cache memory: ",
            f"{format_gib(self.non_kv_cache_memory)}GiB; ",
            "torch peak memory increase: ",
            f"{format_gib(self.torch_peak_increase)}GiB; ",
            "non-torch forward increase memory: ",
            f"{format_gib(self.non_torch_increase)}GiB; ",
            f"weights memory: {format_gib(self.weights_memory)}GiB.",
        ]
        return "".join(pieces)

MemorySnapshot dataclass

Memory snapshot.

Source code in vllm/utils/mem_utils.py
@dataclass
class MemorySnapshot:
    """Point-in-time view of device memory usage (sizes in bytes).

    With ``auto_measure=True`` (the default) the fields are populated from the
    live device state as soon as the object is built. With
    ``auto_measure=False`` the values passed to the constructor are kept until
    ``measure()`` is called. ``device=None`` selects the platform's current
    device; the resolved ``torch.device`` is stored in ``device_``.
    """

    torch_peak: int = 0
    free_memory: int = 0
    total_memory: int = 0
    cuda_memory: int = 0
    torch_memory: int = 0
    non_torch_memory: int = 0
    timestamp: float = 0.0

    device: torch.types.Device = None
    auto_measure: bool = True

    def __post_init__(self) -> None:
        if self.device is not None:
            self.device_ = torch.device(self.device)
        else:
            get_current = current_platform.current_device
            assert get_current is not None
            self.device_ = torch.device(get_current())

        if not self.auto_measure:
            return
        self.measure()

    def measure(self) -> None:
        """Overwrite every field with the current state of ``device_``."""
        dev = self.device_

        # The peak is taken from "allocated_bytes.all.peak" rather than from
        # memory_reserved(): after reset_peak_memory_stats(), the reserved
        # figure keeps growing and only shrinks on empty_cache() or an OOM.
        allocator_stats = torch.accelerator.memory_stats(dev)
        self.torch_peak = allocator_stats.get("allocated_bytes.all.peak", 0)

        self.free_memory, self.total_memory = current_platform.mem_get_info(dev)
        if current_platform.is_integrated_gpu(dev.index):
            # On UMA platforms (CPU and GPU share physical memory, e.g. GH200,
            # DGX Spark, Jetson Orin) cudaMemGetInfo underreports free memory
            # because it ignores reclaimable OS memory (page cache, buffers),
            # so psutil's "available" figure is used instead.
            # https://docs.nvidia.com/cuda/cuda-for-tegra-appnote/#estimating-total-allocatable-device-memory-on-an-integrated-gpu-device
            self.free_memory = psutil.virtual_memory().available

        self.cuda_memory = self.total_memory - self.free_memory

        # Bytes PyTorch has obtained from the driver (cudaMalloc etc.); the
        # rest of the used device memory is attributed to non-torch users.
        self.torch_memory = torch.accelerator.memory_reserved(dev)
        self.non_torch_memory = self.cuda_memory - self.torch_memory

        self.timestamp = time.time()

    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
        """Return the field-wise difference ``self - other`` (unmeasured).

        Raises:
            ValueError: if the two snapshots were taken on different devices.
        """
        if self.device_ == other.device_:
            numeric_fields = (
                "torch_peak",
                "free_memory",
                "total_memory",
                "cuda_memory",
                "torch_memory",
                "non_torch_memory",
                "timestamp",
            )
            deltas = {
                name: getattr(self, name) - getattr(other, name)
                for name in numeric_fields
            }
            return MemorySnapshot(**deltas, device=self.device_, auto_measure=False)

        raise ValueError(
            "The two snapshots should be from the same device! "
            f"Found: {self.device_} vs. {other.device_}"
        )

    def __repr__(self) -> str:
        parts = [
            f"torch_peak={format_gib(self.torch_peak)}GiB",
            f"free_memory={format_gib(self.free_memory)}GiB",
            f"total_memory={format_gib(self.total_memory)}GiB",
            f"{current_platform.device_name}_memory={format_gib(self.cuda_memory)}GiB",
            f"torch_memory={format_gib(self.torch_memory)}GiB",
            f"non_torch_memory={format_gib(self.non_torch_memory)}GiB",
            f"timestamp={self.timestamp}",
            f"auto_measure={self.auto_measure}",
        ]
        return ", ".join(parts)

get_cpu_memory()

Returns the total CPU memory of the node in bytes.

Source code in vllm/utils/mem_utils.py
def get_cpu_memory() -> int:
    """Report the node's total physical RAM, in bytes, as seen by psutil."""
    vm_info = psutil.virtual_memory()
    return vm_info.total

get_max_shared_memory_bytes(gpu=0) cached

Returns the maximum shared memory per thread block in bytes.

Source code in vllm/utils/mem_utils.py
@cache
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes.

    Args:
        gpu: index of the device whose attribute is queried.

    Returns:
        The device's maximum shared memory per block, as an ``int``.

    Raises:
        RuntimeError: if the device reports a non-positive value.
    """
    from vllm import _custom_ops as ops

    max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
    # A value of 0 would make MAX_SEQ_LEN negative and test_attention.py would
    # fail. Use an explicit check rather than `assert` so the guard is not
    # stripped when Python runs with -O.
    if max_shared_mem <= 0:
        raise RuntimeError(
            f"max_shared_mem must be positive, got {max_shared_mem}"
        )
    return int(max_shared_mem)

memory_profiling(baseline_snapshot, weights_memory=0)

Memory profiling context manager.

baseline_snapshot: the memory snapshot taken before the current vLLM instance was created. weights_memory: memory used by PyTorch when loading the model weights. Note that before loading the model weights we also initialize the device and the distributed environment, which may consume some memory; that part is not included in weights_memory because PyTorch does not control it.

The memory in one GPU can be classified into 3 categories: 1. memory used by anything other than the current vLLM instance; 2. memory used by torch in the current vLLM instance; 3. memory used by the current vLLM instance, but not by torch.

A quantitative example:

Before creating the current vLLM instance

category 1: 1 GiB; category 2: 0 GiB; category 3: 0 GiB.

After creating the current vLLM instance and loading the model (i.e. before profiling): category 1: 1 GiB; category 2: 2 GiB (model weights take 2 GiB); category 3: 0.5 GiB (memory used by NCCL).

During profiling (peak): category 1: 1 GiB; category 2: 4 GiB (peak activation tensors take 2 GiB); category 3: 1 GiB (memory used by NCCL + buffers for some attention backends).

After profiling

category 1: 1 GiB; category 2: 3 GiB (after garbage-collecting activation tensors); category 3: 1 GiB (memory used by NCCL + buffers for some attention backends).

In this case, non-KV-cache memory takes 5 GiB in total, including: a. 2 GiB used by the model weights (category 2); b. 2 GiB reserved for the peak activation tensors (category 2); c. 1 GiB used by non-torch components (category 3).

The memory used for loading weights (a.) is given directly by the argument weights_memory.

The increase of torch.accelerator.memory_stats()["allocated_bytes.all.peak"] during profiling gives (b.).

The increase of non_torch_memory from creating the current vLLM instance until after profiling gives (c.).

Source code in vllm/utils/mem_utils.py
@contextlib.contextmanager
def memory_profiling(
    baseline_snapshot: MemorySnapshot,
    weights_memory: int = 0,
) -> Generator[MemoryProfilingResult, None, None]:
    """
    Memory profiling context manager.

    Yields a `MemoryProfilingResult` whose `before_profile` snapshot has
    already been measured. Its derived fields (`torch_peak_increase`,
    `non_torch_increase`, `profile_time`, `non_kv_cache_memory`) are only
    filled in after the `with` body exits normally. If the body raises, the
    exception propagates and those fields keep their default values.

    baseline_snapshot: the memory snapshot before the current vLLM instance.
    weights_memory: memory used by PyTorch when loading the model weights.
        Note that, before loading the model weights, we also initialize the device
        and distributed environment, which may consume some memory. This part is not
        included in the weights_memory because PyTorch does not control it.

    The memory in one GPU can be classified into 3 categories:
    1. memory used by anything other than the current vLLM instance.
    2. memory used by torch in the current vLLM instance.
    3. memory used in the current vLLM instance, but not by torch.

    A quantitative example:

    Before creating the current vLLM instance:
        category 1: 1 GiB
        category 2: 0 GiB
        category 3: 0 GiB

    After creating the current vLLM instance and loading the model,
    (i.e. before profiling):
        category 1: 1 GiB
        category 2: 2 GiB (model weights take 2 GiB)
        category 3: 0.5 GiB (memory used by NCCL)

    During profiling (peak):
        category 1: 1 GiB
        category 2: 4 GiB (peak activation tensors take 2 GiB)
        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    After profiling:
        category 1: 1 GiB
        category 2: 3 GiB (after garbage-collecting activation tensors)
        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    In this case, non-kv cache takes 5 GiB in total, including:
    a. 2 GiB used by the model weights (category 2)
    b. 2 GiB reserved for the peak activation tensors (category 2)
    c. 1 GiB used by non-torch components (category 3)

    The memory used for loading weights (a.) is given directly by the
    argument `weights_memory`.

    The increase of `torch.accelerator.memory_stats()["allocated_bytes.all.peak"]`
    during profiling gives (b.).

    The increase of `non_torch_memory` from creating the current vLLM instance
    until after profiling gives (c.).
    """
    # Collect unreachable Python objects and return cached-but-unused allocator
    # blocks before measuring, so the pre-profile snapshot is not inflated by
    # leftovers.
    gc.collect()
    torch.accelerator.empty_cache()
    # Reset the peak counter so the `allocated_bytes.all.peak` value read by
    # `after_profile.measure()` reflects only the profiled region.
    torch.accelerator.reset_peak_memory_stats(baseline_snapshot.device_)

    result = MemoryProfilingResult(
        before_create=baseline_snapshot,
        # the part of memory used for holding the model weights
        weights_memory=weights_memory,
    )

    result.before_profile.measure()

    yield result

    # Reached only if the `with` body completed without raising.
    # Free garbage (e.g. activation tensors) and cached blocks again before
    # the final measurement.
    gc.collect()
    torch.accelerator.empty_cache()

    result.after_profile.measure()

    # Peak torch growth is measured across the profiled region only. Non-torch
    # growth is measured from the baseline (before the vLLM instance existed),
    # so it also includes non-torch memory allocated during startup.
    diff_profile = result.after_profile - result.before_profile
    diff_from_create = result.after_profile - result.before_create
    result.torch_peak_increase = diff_profile.torch_peak
    result.non_torch_increase = diff_from_create.non_torch_memory
    # Wall-clock seconds between the two measure() calls.
    result.profile_time = diff_profile.timestamp

    # non_kv_cache_memory = (c.) + (b.) + (a.) from the docstring breakdown.
    non_torch_memory = result.non_torch_increase
    peak_activation_memory = result.torch_peak_increase
    result.non_kv_cache_memory = (
        non_torch_memory + peak_activation_memory + result.weights_memory
    )

release_device_memory_under_pressure(device)

On integrated (UMA) GPUs, release caching-allocator memory back to the OS when system memory pressure is high. The OS may start thrashing before an allocation failure would trigger PyTorch's own cache release.

Returns:

  • bool

    True if memory was released.

Source code in vllm/utils/mem_utils.py
def release_device_memory_under_pressure(device: torch.device) -> bool:
    """Return cached allocator memory to the OS on a UMA GPU under pressure.

    On integrated GPUs, device memory is ordinary system RAM, so the OS may
    start thrashing long before an allocation failure would make PyTorch
    free its cache on its own. This helper empties the cache early when:
      * ``device`` is a CUDA device that the platform reports as integrated,
      * at least ``_UMA_MIN_RELEASE_BYTES`` are reserved but unallocated, and
      * system available memory has dropped to
        ``(1 - _UMA_PRESSURE_THRESHOLD) * total`` or less.

    Returns:
        True if the cache was emptied, otherwise False.
    """
    on_uma_cuda = device.type == "cuda" and current_platform.is_integrated_gpu(
        device.index
    )
    if not on_uma_cuda:
        return False

    reserved_bytes = torch.accelerator.memory_reserved(device)
    allocated_bytes = torch.accelerator.memory_allocated(device)
    releasable = reserved_bytes - allocated_bytes
    if releasable < _UMA_MIN_RELEASE_BYTES:
        return False

    # cudaMemGetInfo underreports free memory on UMA (see
    # MemorySnapshot.measure), so system pressure is read from psutil.
    vm_info = psutil.virtual_memory()
    pressure_limit = (1 - _UMA_PRESSURE_THRESHOLD) * vm_info.total
    if vm_info.available > pressure_limit:
        return False

    torch.accelerator.synchronize(device)
    torch.accelerator.empty_cache()
    logger.debug(
        "Released %sGiB of cached device memory under memory pressure",
        format_gib(releasable),
    )
    return True