Skip to content

vllm.device_allocator.xpumem

Classes:

XpuMemAllocator

A singleton pluggable allocator helper for XPU.

Note: Sleep will offload selected payloads to CPU or discard and unmap XPU physical memory. Wake-up remaps physical memory back to the same reserved virtual address and restores payload.

Methods:

  • release_pools

    Drop Python references to MemPool/pluggable allocators eagerly.

Source code in vllm/device_allocator/xpumem.py
class XpuMemAllocator:
    """A singleton pluggable allocator helper for XPU.

    Note:
    Sleep will offload selected payloads to CPU or discard and unmap XPU
    physical memory. Wake-up remaps physical memory back to the same
    reserved virtual address and restores payload.
    """

    instance: "XpuMemAllocator | None" = None
    default_tag: str = "default"

    @staticmethod
    def get_instance() -> "XpuMemAllocator":
        assert xpumem_available, "xpumem allocator is not available"
        if XpuMemAllocator.instance is None:
            XpuMemAllocator.instance = XpuMemAllocator()
            # Ensure MemPool/allocator wrappers are released before interpreter
            # finalization tears down XPU runtime internals.
            atexit.register(XpuMemAllocator._shutdown_singleton)
        return XpuMemAllocator.instance

    @staticmethod
    def _shutdown_singleton() -> None:
        instance = XpuMemAllocator.instance
        if instance is None:
            return
        try:
            instance.release_pools()
        except Exception:
            logger.exception("XpuMemAllocator singleton shutdown failed")

    def __init__(self):
        self.pointer_to_data: dict[int, AllocationData] = {}
        self.current_tag: str = XpuMemAllocator.default_tag
        self.allocator_and_pools: dict[str, Any] = {}
        self.python_malloc_callback = self._python_malloc_callback
        self.python_free_callback = self._python_free_callback

    def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
        ptr = allocation_handle[2]
        self.pointer_to_data[ptr] = AllocationData(allocation_handle, self.current_tag)
        logger.debug(
            "Allocated %s bytes for %s at %s",
            allocation_handle[1],
            self.current_tag,
            ptr,
        )

    def _python_free_callback(self, ptr: int) -> HandleType:
        data = self.pointer_to_data.pop(ptr)
        data.cpu_backup_tensor = None
        logger.debug("Freed %s bytes for %s at %s", data.handle[1], data.tag, ptr)
        return data.handle

    def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None:
        if offload_tags is None:
            offload_tags = (XpuMemAllocator.default_tag,)
        elif isinstance(offload_tags, str):
            offload_tags = (offload_tags,)

        assert isinstance(offload_tags, tuple)

        total_bytes = 0
        backup_bytes = 0

        for ptr, data in self.pointer_to_data.items():
            size_in_bytes = data.handle[1]
            total_bytes += size_in_bytes
            if data.tag not in offload_tags:
                unmap_and_release(data.handle)
                continue

            backup_bytes += size_in_bytes
            device, _, _, _ = data.handle
            cpu_backup_tensor = torch.empty(
                size_in_bytes,
                dtype=torch.uint8,
                device="cpu",
                pin_memory=is_pin_memory_available(),
            )
            cpu_ptr = cpu_backup_tensor.data_ptr()
            _xpu_memcpy_sync(
                cpu_ptr,
                ptr,
                size_in_bytes,
                MEMCPY_DEVICE_TO_HOST,
                device,
            )
            data.cpu_backup_tensor = cpu_backup_tensor

            unmap_and_release(data.handle)

        logger.info(
            "XpuMemAllocator: sleep freed %.2f GiB memory in total, of which "
            "%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded "
            "directly.",
            total_bytes / 1024**3,
            backup_bytes / 1024**3,
            (total_bytes - backup_bytes) / 1024**3,
        )

        gc.collect()
        xpu_empty_cache = getattr(torch.xpu, "empty_cache", None)
        if callable(xpu_empty_cache):
            xpu_empty_cache()

    def wake_up(self, tags: list[str] | None = None) -> None:
        for ptr, data in self.pointer_to_data.items():
            if tags is not None and data.tag not in tags:
                continue
            create_and_allocate(data.handle)

            cpu_backup_tensor = data.cpu_backup_tensor
            if cpu_backup_tensor is None:
                continue

            device, size_in_bytes, _, _ = data.handle
            _xpu_memcpy_sync(
                ptr,
                cpu_backup_tensor.data_ptr(),
                size_in_bytes,
                MEMCPY_HOST_TO_DEVICE,
                device,
            )
            data.cpu_backup_tensor = None

    def release_pools(self) -> None:
        """Drop Python references to MemPool/pluggable allocators eagerly.

        This prevents pool destruction from being deferred to interpreter
        finalization, which can happen after parts of XPU runtime are already
        torn down.
        """
        if not self.allocator_and_pools:
            return

        # Note: keep allocators alive while MemPool objects are destroyed.
        # MemPool teardown may invoke allocator virtual methods (e.g. raw_delete)
        # when releasing cached blocks. If allocator wrappers are dropped first,
        # C++ can hit "pure virtual method called" during shutdown.
        pool_entries = list(self.allocator_and_pools.values())
        self.allocator_and_pools.clear()

        mem_pools = [entry[0] for entry in pool_entries]
        allocators = [entry[1] for entry in pool_entries]
        pool_entries.clear()

        xpu_sync = getattr(torch.xpu, "synchronize", None)
        if callable(xpu_sync):
            try:
                xpu_sync()
            except Exception:
                logger.debug("torch.xpu.synchronize() failed during release_pools")

        # Phase 1: drop MemPool refs while allocators are still strongly held.
        mem_pools.clear()
        gc.collect()

        # Phase 2: now it is safe to release allocator wrappers.
        allocators.clear()

    @contextmanager
    def use_memory_pool(self, tag: str | None = None):
        if tag is None:
            tag = XpuMemAllocator.default_tag

        old_tag = self.current_tag
        self.current_tag = tag
        try:
            with use_memory_pool_with_allocator(
                self.python_malloc_callback,
                self.python_free_callback,
            ) as data:
                self.allocator_and_pools[tag] = data
                yield
        finally:
            self.current_tag = old_tag

    def get_current_usage(self) -> int:
        total = 0
        for data in self.pointer_to_data.values():
            total += data.handle[1]
        return total

release_pools()

Drop Python references to MemPool/pluggable allocators eagerly.

This prevents pool destruction from being deferred to interpreter finalization, which can happen after parts of XPU runtime are already torn down.

Source code in vllm/device_allocator/xpumem.py
def release_pools(self) -> None:
    """Drop Python references to MemPool/pluggable allocators eagerly.

    This prevents pool destruction from being deferred to interpreter
    finalization, which can happen after parts of XPU runtime are already
    torn down.

    Does nothing when ``self.allocator_and_pools`` is empty. Each stored
    entry is read positionally: index 0 as the MemPool and index 1 as the
    allocator. A failing ``torch.xpu.synchronize()`` is logged at debug
    level, with the exception details, and otherwise ignored.
    """
    if not self.allocator_and_pools:
        return

    # Note: keep allocators alive while MemPool objects are destroyed.
    # MemPool teardown may invoke allocator virtual methods (e.g. raw_delete)
    # when releasing cached blocks. If allocator wrappers are dropped first,
    # C++ can hit "pure virtual method called" during shutdown.
    pool_entries = list(self.allocator_and_pools.values())
    self.allocator_and_pools.clear()

    mem_pools = [entry[0] for entry in pool_entries]
    allocators = [entry[1] for entry in pool_entries]
    pool_entries.clear()

    # synchronize is looked up defensively and skipped when absent.
    xpu_sync = getattr(torch.xpu, "synchronize", None)
    if callable(xpu_sync):
        try:
            xpu_sync()
        except Exception:
            # Best effort: keep going, but record why the sync failed.
            logger.debug(
                "torch.xpu.synchronize() failed during release_pools",
                exc_info=True,
            )

    # Phase 1: drop MemPool refs while allocators are still strongly held.
    mem_pools.clear()
    gc.collect()

    # Phase 2: now it is safe to release allocator wrappers.
    allocators.clear()