
llmcompressor.pipelines.cache

Classes:

  • IntermediateValue

    Dataclass which recursively defines offloaded values and which device to onload to

  • IntermediatesCache

    Cache which stores intermediate values (activations) produced by batched, sequential execution of models

  • OverrideEqMode

    When using a torch.Tensor as a key in a dictionary, the equality check must return a single value instead of a torch.Tensor of bool values

IntermediateValue dataclass

IntermediateValue(
    value: Tensor | "IntermediateValue" | Any,
    device: device | None,
)

Dataclass which recursively defines offloaded values and which device to onload to

Parameters:

  • value (Tensor | 'IntermediateValue' | Any) –

    either an offloaded Tensor, a primitive value, or a recursively nested value (such as a tuple or dataclass of values)

  • device (device | None) –

    if the value is a Tensor, then the device to onload the tensor to, otherwise None

IntermediatesCache

IntermediatesCache(
    batch_intermediates: list[IntermediateValues]
    | None = None,
    offload_device: device | None = "cpu",
)

Cache which stores intermediate values (activations) produced by batched, sequential execution of models. Values are offloaded to the offload_device when stored in the cache and onloaded to their original device when fetched from the cache. If offload_device is None, values will not be offloaded at all.

Currently supports nested offloading of dataclass instances and tuples.

Construct instances using the empty and from_dataloader class methods.

Methods:

  • append

    Append new values to the cache. The new values will be assigned the next available batch index

  • delete

    Delete values from the cache

  • empty

    Construct an empty cache

  • fetch

    Fetch values belonging to a batch

  • from_dataloader

    Initialize a cache with data from the provided dataloader

  • iter_prefetch

    Iterate over batches with the next batch prefetched in a background thread.

  • size

    Returns the memory used by cached values, keyed by device, in bytes

  • update

    Update/put values belonging to a batch

Source code in src/llmcompressor/pipelines/cache.py
def __init__(
    self,
    batch_intermediates: list[IntermediateValues] | None = None,
    offload_device: torch.device | None = "cpu",
):
    self.batch_intermediates = batch_intermediates or []
    self.offload_device = offload_device

append

append(values: dict[str, Any])

Append new values to the cache. The new values will be assigned the next available batch index.

Parameters:

  • values (dict[str, Any]) –

    dictionary mapping keys to values used for update

Source code in src/llmcompressor/pipelines/cache.py
def append(self, values: dict[str, Any]):
    """
    Append new values to the cache. The new values will be assigned the next
    available batch index

    :param values: dictionary mapping keys to values used for update
    """
    batch_index = len(self.batch_intermediates)
    self.batch_intermediates.append({})
    self.update(batch_index, values)

delete

delete(
    batch_index: int,
    consumed_names: list[str] | None = None,
)

Delete values from the cache

Parameters:

  • batch_index (int) –

    index of batch whose values will be deleted

  • consumed_names (list[str] | None, default: None ) –

    list of keys whose values will be deleted, defaults to removing all keys

Source code in src/llmcompressor/pipelines/cache.py
def delete(self, batch_index: int, consumed_names: list[str] | None = None):
    """
    Delete values from the cache

    :param batch_index: index of batch whose values will be deleted
    :param consumed_names: list of keys whose values will be deleted, defaults to
        removing all keys
    """
    intermediates = self.batch_intermediates[batch_index]

    if consumed_names is None:
        consumed_names = list(intermediates.keys())

    for name in consumed_names:
        del intermediates[name]

empty classmethod

empty(num_batches: int, offload_device: device)

Construct an empty cache

Parameters:

  • num_batches (int) –

    the expected number of batches to be stored

  • offload_device (device) –

    device to offload values to

Source code in src/llmcompressor/pipelines/cache.py
@classmethod
def empty(cls, num_batches: int, offload_device: torch.device):
    """
    Construct an empty cache

    :param num_batches: the expected number of batches to be stored
    :param offload_device: device to offload values to
    """
    batch_intermediates = [{} for _ in range(num_batches)]
    return cls(batch_intermediates, offload_device)

fetch

fetch(
    batch_index: int, input_names: list[str] | None = None
) -> dict[str, Any]

Fetch values belonging to a batch

Parameters:

  • batch_index (int) –

    index of batch whose values are being fetched

  • input_names (list[str] | None, default: None ) –

    list of keys whose values are being fetched, defaults to fetching all keys

Returns:

  • dict[str, Any]

    dictionary mapping keys to onloaded values

Source code in src/llmcompressor/pipelines/cache.py
def fetch(
    self, batch_index: int, input_names: list[str] | None = None
) -> dict[str, Any]:
    """
    Fetch values belonging to a batch

    :param batch_index: index of batch whose values are being fetched
    :param input_names: list of keys whose values are being fetched
    :return: dictionary mapping keys to onloaded values
    """
    intermediates = self.batch_intermediates[batch_index]

    return {
        key: self._onload_value(subgraph_input)
        for key, subgraph_input in intermediates.items()
        if input_names is None or key in input_names
    }

from_dataloader classmethod

from_dataloader(
    dataloader: DataLoader,
    model_device: device = torch.device("cpu"),
    offload_device: device | None = torch.device("cpu"),
)

Initialize a cache with data from the provided dataloader

This method iterates through all batches in the dataloader and offloads them to the specified device. For faster cache preparation, consider:

  • Increasing batch_size to reduce the number of iterations
  • Using num_workers > 0 in the DataLoader for parallel loading (e.g. the calibration DataLoader from format_calibration_data uses dataloader_num_workers; when > 0, pin_memory and prefetch_factor are also set where applicable, which speeds up both cache preparation and calibration)
  • Ensuring data preprocessing is done before creating the dataloader

Parameters:

  • dataloader (DataLoader) –

    dataloader which generates values to be cached

  • model_device (device, default: device('cpu') ) –

    device which values will be onloaded to when fetched

  • offload_device (device | None, default: device('cpu') ) –

    device to offload values to

Source code in src/llmcompressor/pipelines/cache.py
@classmethod
def from_dataloader(
    cls,
    dataloader: torch.utils.data.DataLoader,
    model_device: torch.device = torch.device("cpu"),
    offload_device: torch.device | None = torch.device("cpu"),
):
    """
    Initialize a cache with data from the provided dataloader

    This method iterates through all batches in the dataloader and offloads
    them to the specified device. For faster cache preparation, consider:
    - Increasing batch_size to reduce the number of iterations
    - Using num_workers > 0 in the DataLoader for parallel loading (e.g. the
      calibration DataLoader from format_calibration_data uses
      dataloader_num_workers; when > 0, pin_memory and prefetch_factor are
      also set where applicable, which speeds both cache build and calibration)
    - Ensuring data preprocessing is done before creating the dataloader

    :param dataloader: dataloader which generates values to be cached
    :param model_device: device which values will be onloaded to when fetched
    :param offload_device: device to offload values to
    """
    batch_intermediates = [
        {
            key: cls._offload_value(value, offload_device, model_device)
            for key, value in batch.items()
        }
        for batch in tqdm(dataloader, desc="Preparing cache")
    ]

    return cls(batch_intermediates, offload_device)

iter_prefetch

iter_prefetch(
    input_names: list[str] | None = None,
) -> Generator[Any, None, None]

Iterate over batches with the next batch prefetched in a background thread. Overlaps onload from offload_device with consumption of the current batch, which can reduce wall-clock time when offloading to CPU.

When CUDA is available, uses non_blocking transfers (requires pinned CPU tensors, set up by _offload_value) and synchronises via CUDA events so the main stream waits for each H2D copy before running GPU kernels on the data.

Yields the same fetched batch dicts as iter; only the timing of onloads differs.

Source code in src/llmcompressor/pipelines/cache.py
def iter_prefetch(
    self, input_names: list[str] | None = None
) -> Generator[Any, None, None]:
    """
    Iterate over batches with the next batch prefetched in a background thread.
    Overlaps onload from offload_device with consumption of the current batch,
    which can reduce wall-clock time when offloading to CPU.

    When CUDA is available, uses non_blocking transfers (requires pinned CPU
    tensors, set up by _offload_value) and synchronises via CUDA events so the
    main stream waits for each H2D copy before running GPU kernels on the data.

    Yields the same fetched batch dicts as :meth:`iter`; only the timing
    of onloads differs.
    """
    num_batches = len(self.batch_intermediates)
    if num_batches == 0:
        return

    # Create a dedicated CUDA stream for H2D transfers so they run on a
    # separate stream from the main thread's compute stream. Without this,
    # both threads default to the null stream (stream 0) which serializes
    # all operations and prevents any overlap.
    h2d_stream = torch.cuda.Stream() if torch.cuda.is_available() else None

    def _fetch_and_record(batch_index):
        event = None
        if h2d_stream is not None:
            with torch.cuda.stream(h2d_stream):
                data = self.fetch(batch_index, input_names)
            event = torch.cuda.Event()
            event.record(h2d_stream)
        else:
            data = self.fetch(batch_index, input_names)
        return data, event

    with ThreadPoolExecutor(max_workers=1) as executor:
        future = None
        for batch_index in range(num_batches):
            if future is not None:
                current, event = future.result()
            else:
                current, event = _fetch_and_record(batch_index)
            if batch_index + 1 < num_batches:
                future = executor.submit(_fetch_and_record, batch_index + 1)
            else:
                future = None
            # Make the main CUDA stream wait for the background H2D copy
            # before any GPU kernel consumes the prefetched tensors
            if event is not None:
                torch.cuda.current_stream().wait_event(event)
            yield current
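The thread-level prefetching that iter_prefetch relies on (fetch batch i+1 on a single worker thread while the caller consumes batch i) can be shown without CUDA. This is a simplified sketch: slow_fetch is a hypothetical stand-in for fetch, and there are no streams or events, so it demonstrates only the overlap and ordering, not the H2D synchronisation.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Generator


def iter_prefetch(
    fetch: Callable[[int], Any], num_batches: int
) -> Generator[Any, None, None]:
    """Yield fetch(0), fetch(1), ... while the next batch loads on a worker thread."""
    if num_batches == 0:
        return
    with ThreadPoolExecutor(max_workers=1) as executor:
        future = None
        for batch_index in range(num_batches):
            # Batch 0 is fetched synchronously; later batches were prefetched
            current = future.result() if future is not None else fetch(batch_index)
            # Kick off the next fetch before handing the current batch to the caller
            if batch_index + 1 < num_batches:
                future = executor.submit(fetch, batch_index + 1)
            else:
                future = None
            yield current


def slow_fetch(batch_index: int) -> dict:
    time.sleep(0.05)  # hypothetical onload cost
    return {"batch": batch_index}


start = time.perf_counter()
seen = []
for batch in iter_prefetch(slow_fetch, 4):
    time.sleep(0.05)  # hypothetical compute on the current batch
    seen.append(batch["batch"])
elapsed = time.perf_counter() - start
print(seen, f"{elapsed:.2f}s")
```

Because time.sleep releases the GIL, the loop on an otherwise idle machine should finish in roughly 0.25 s instead of the 0.4 s a sequential fetch-then-compute loop would need; a single worker keeps at most one batch in flight, which bounds the extra memory held by prefetching.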

size

size() -> dict[torch.device, int]

Returns the memory used by cached values, keyed by device, in bytes

Returns:

  • dict[device, int]

    dictionary mapping torch device to number of bytes in cache

Source code in src/llmcompressor/pipelines/cache.py
def size(self) -> dict[torch.device, int]:
    """
    Returns the memory used by cached values, keyed by device, in bytes

    :return: dictionary mapping torch device to number of bytes in cache
    """
    sizes = defaultdict(lambda: 0)
    memo = set()

    def _size_helper(intermediate: IntermediateValue) -> int:
        value = intermediate.value

        match value:
            case torch.Tensor():
                if value not in memo:
                    sizes[value.device] += value.nbytes
                memo.add(value)
            case list() | tuple():
                for v in value:
                    _size_helper(v)
            case dict():
                for v in value.values():
                    _size_helper(v)
            case _ if is_dataclass(value):
                for field in fields(value):
                    _size_helper(getattr(value, field.name))
            case _:
                # this handles primitive values that don't match any other cases
                sizes[torch.device("cpu")] += sys.getsizeof(value, 0)

    for intermediates in self.batch_intermediates:
        for value in intermediates.values():
            _size_helper(value)

    return dict(sizes)
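The accounting in size (count each tensor once even if several batches share it, bucket bytes by device, attribute primitives to CPU) can be reproduced in a minimal sketch. Buf is a hypothetical tensor stand-in exposing only nbytes and device, and cache_size is a hypothetical helper that walks plain nested containers without IntermediateValue wrappers. It remembers seen buffers by id() rather than storing the objects in a set, which avoids relying on the objects' equality semantics.

```python
import sys
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class Buf:
    """Hypothetical tensor stand-in: only byte size and device matter here."""

    nbytes: int
    device: str


def cache_size(batches: list[dict]) -> dict[str, int]:
    """Sum bytes per device across all batches, counting shared buffers once."""
    sizes: dict[str, int] = defaultdict(int)
    seen: set[int] = set()  # ids of buffers already counted

    def visit(value) -> None:
        if isinstance(value, Buf):
            if id(value) not in seen:
                seen.add(id(value))
                sizes[value.device] += value.nbytes
        elif isinstance(value, (list, tuple)):
            for v in value:
                visit(v)
        elif isinstance(value, dict):
            for v in value.values():
                visit(v)
        else:
            # Primitives are attributed to CPU, as in the listing above
            sizes["cpu"] += sys.getsizeof(value, 0)

    for batch in batches:
        for value in batch.values():
            visit(value)
    return dict(sizes)


shared = Buf(nbytes=1024, device="cuda:0")
batches = [
    {"hidden": Buf(4096, "cpu"), "rotary": shared},
    {"hidden": Buf(4096, "cpu"), "rotary": shared},  # same object: counted once
]
print(cache_size(batches))  # {'cpu': 8192, 'cuda:0': 1024}
```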

update

update(batch_index: int, values: dict[str, Any])

Update/put values belonging to a batch

Parameters:

  • batch_index (int) –

    index of batch whose values will be updated

  • values (dict[str, Any]) –

    dictionary mapping keys to values used for update

Source code in src/llmcompressor/pipelines/cache.py
def update(self, batch_index: int, values: dict[str, Any]):
    """
    Update/put values belonging to a batch

    :param batch_index: index of batch whose values will be updated
    :param values: dictionary mapping keys to values used for update
    """
    device = self.offload_device
    intermediates = {k: self._offload_value(v, device) for k, v in values.items()}
    self.batch_intermediates[batch_index].update(intermediates)

OverrideEqMode

Bases: TorchDispatchMode

When using a torch.Tensor as a key in a dictionary, the equality check must return a single value instead of a torch.Tensor of bool values. Use this override context in such cases to swap out the torch.eq equality check for a check on id.

>>> a = torch.tensor([1,2,3])
>>> b = torch.tensor([1,2,3])
>>> a == b
tensor([True, True, True])
>>> with OverrideEqMode():
...     a == b
tensor(True)
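Why a dictionary needs a single boolean from == can be reproduced in pure Python. In the sketch below, the hypothetical Vec class compares elementwise and, like a multi-element bool tensor, its comparison result has no truth value; Vec also uses a deliberately coarse hash so the dict must fall back to ==. The hypothetical IdKey wrapper is a swapped-in technique that mirrors the id-based check described above; it is not how OverrideEqMode itself works, since that class intercepts torch.eq through TorchDispatchMode.

```python
class Elementwise(list):
    """Elementwise comparison result with no truth value."""

    def __bool__(self):
        raise RuntimeError("truth value of an elementwise comparison is ambiguous")


class Vec:
    """Hypothetical tensor-like key whose == compares elementwise."""

    def __init__(self, *items):
        self.items = items

    def __eq__(self, other):
        return Elementwise(x == y for x, y in zip(self.items, other.items))

    def __hash__(self):
        # Deliberately coarse hash so distinct keys collide and == is consulted
        return len(self.items)


class IdKey:
    """Wrap a key so dict equality is a single bool based on object identity."""

    def __init__(self, obj):
        self.obj = obj

    def __eq__(self, other):
        return isinstance(other, IdKey) and self.obj is other.obj

    def __hash__(self):
        return id(self.obj)


a, b = Vec(1, 2, 3), Vec(1, 2, 3)
raw = {a: "a"}
lookup_failed = False
try:
    raw[b]  # hashes collide, so the dict calls == and then bool() on the result
except RuntimeError as err:
    lookup_failed = True
    print("lookup failed:", err)

safe = {IdKey(a): "a"}
print(IdKey(a) in safe, IdKey(b) in safe)  # True False
```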