
llmcompressor.observers

Framework for monitoring and analyzing model behavior during compression.

Provides observers for tracking tensor statistics, activation ranges, and model behavior during compression workflows. Includes min-max observers, MSE observers, and helper utilities for quantization and other compression techniques.


Classes:

  • IMatrixMSEObserver

    MSE observer weighted by per-input-channel importance (E[x²]).

  • MemorylessMinMaxObserver

    Compute quantization parameters by taking the min/max of the observed value.

  • MinMaxObserver

    Compute quantization parameters by taking the moving average of min/max values.

  • MovingAverageMSEObserver

    Compute quantization parameters by finding the optimal min/max values which minimize the mean of quantization error squared.

  • Observer

    Base class for observers which compute quantization parameters given observations of weights, activations, or attention states.

  • QParamsDict

    Dictionary containing quantization parameters.

  • StaticMinMaxObserver

    Compute quantization parameters by taking the min/max of all observed values.

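Functions:

  • flatten_for_calibration

    Reshapes the value according to the quantization strategy for the purposes of scale/zp calibration.

  • fuse_weight_observers

    Link weight observers across fused layer groups for shared global_scale.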

IMatrixMSEObserver

IMatrixMSEObserver(*args, **kwargs)

Bases: Observer

MSE observer weighted by per-input-channel importance (E[x²]).

Supports CHANNEL, GROUP, and TENSOR_GROUP for weight-only Linear modules. Falls back to uniform MSE when importance data is unavailable.

Importance is accumulated as a raw per-channel sum of squared inputs (_imatrix_sum) and a token count (_imatrix_count); both are synced across DDP ranks via _act_sync_dict before observation.
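A minimal sketch, in plain Python rather than torch, of why importance is kept as a raw sum and count instead of a running mean: each batch adds the per-channel sum of x² and its token count, E[x²] is recovered as sum / count, and because both quantities are additive, summing them across ranks gives exactly the same result as processing all tokens on one rank. The function names are hypothetical, and the sketch assumes the leading (batch, sequence) dimensions have already been flattened into a list of tokens.

```python
from typing import List, Sequence, Tuple

# (per-channel sum of x^2, number of tokens seen)
Accumulator = Tuple[List[float], int]


def accumulate(batch: Sequence[Sequence[float]], running: Accumulator) -> Accumulator:
    """Fold a batch of shape (n_tokens, in_features) into the running accumulator."""
    running_sum, count = running
    new_sum = list(running_sum)  # copy so the input accumulator is not mutated
    for token in batch:
        for channel, value in enumerate(token):
            new_sum[channel] += value * value
    return new_sum, count + len(batch)


def merge_ranks(partials: Sequence[Accumulator]) -> Accumulator:
    """Element-wise SUM of per-rank sums and counts, as a SUM all-reduce would produce."""
    total_sum = [0.0] * len(partials[0][0])
    total_count = 0
    for partial_sum, partial_count in partials:
        total_sum = [a + b for a, b in zip(total_sum, partial_sum)]
        total_count += partial_count
    return total_sum, total_count


def importance(acc: Accumulator) -> List[float]:
    """Per-channel E[x^2] = sum / count."""
    running_sum, count = acc
    return [s / count for s in running_sum]


empty: Accumulator = ([0.0, 0.0], 0)
rank0 = accumulate([[1.0, 2.0], [3.0, 0.0]], empty)
rank1 = accumulate([[1.0, 1.0]], empty)
merged = merge_ranks([rank0, rank1])
single = accumulate([[1.0, 2.0], [3.0, 0.0], [1.0, 1.0]], empty)
print(merged, importance(merged))
```

Averaging per-rank means directly would weight ranks equally regardless of how many tokens each saw; keeping sum and count avoids that.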

Methods:

  • attach

    Attach a forward-pre hook to accumulate E[x²] per input channel.

  • detach

    Remove hooks and leave raw sum/count on module for second-pass pickup.

Source code in src/llmcompressor/observers/imatrix.py
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    kw = self.args.observer_kwargs
    self.maxshrink = kw.get("maxshrink", 0.95)
    self.patience = kw.get("patience", 5)
    self.grid = kw.get("grid", 20)
    self.norm = kw.get("norm", 3.0)
    self.strict = kw.get("strict", False)

    self._imatrix_sum: Optional[torch.Tensor] = None
    self._imatrix_count: torch.Tensor = torch.tensor(0, dtype=torch.int64)

    if self.grid <= 0:
        raise ValueError(f"grid must be > 0, got {self.grid}")
    if self.patience < 0:
        raise ValueError(f"patience must be >= 0, got {self.patience}")
    if not (0 <= self.maxshrink <= 1):
        raise ValueError(f"maxshrink must be in [0, 1], got {self.maxshrink}")
    if (
        not isinstance(self.norm, (int, float))
        or not math.isfinite(self.norm)
        or self.norm <= 0
    ):
        raise ValueError(f"norm must be a finite positive number, got {self.norm}")

attach

attach(module: Module) -> None

Attach a forward-pre hook to accumulate E[x²] per input channel.

If raw accumulators (_imatrix_sum / _imatrix_count) already exist on the module (second pass after IMatrixGatherer), copy them to the observer and skip hook registration.

Source code in src/llmcompressor/observers/imatrix.py
def attach(self, module: torch.nn.Module) -> None:
    """Attach a forward-pre hook to accumulate E[x²] per input channel.

    If raw accumulators (``_imatrix_sum`` / ``_imatrix_count``) already
    exist on the module (second pass after IMatrixGatherer), copy them
    to the observer and skip hook registration.
    """
    if hasattr(module, "_imatrix_sum"):
        self._imatrix_sum = module._imatrix_sum
        self._imatrix_count = module._imatrix_count
        del module._imatrix_sum
        del module._imatrix_count
        return

    if not hasattr(module, "in_features"):
        return

    in_features = module.in_features
    module._imatrix_sum = torch.zeros(in_features, dtype=IMATRIX_PRECISION)
    module._imatrix_count = torch.tensor(0, dtype=torch.int64)

    def _hook(mod, args):
        x = args[0] if isinstance(args, tuple) else args
        if isinstance(x, tuple):
            x = x[0]
        if x is None or not isinstance(x, torch.Tensor):
            return

        x_f = x.detach().to(IMATRIX_PRECISION)
        device = x_f.device
        n_tokens = math.prod(x_f.shape[:-1])
        token_sum = x_f.pow(2).sum(dim=list(range(x_f.dim() - 1)))

        mod._imatrix_sum = mod._imatrix_sum.to(device)
        mod._imatrix_count = mod._imatrix_count.to(device)

        mod._imatrix_sum.add_(token_sum)
        mod._imatrix_count += n_tokens

    module._imatrix_hook = module.register_forward_pre_hook(_hook)

detach

detach(module: Module) -> None

Remove hooks and leave raw sum/count on module for second-pass pickup.

Case 1 – accumulators present on module: leave them for the next observer's attach() to pick up.

Case 2 – no accumulators (second-pass cleanup): nothing to do.

Source code in src/llmcompressor/observers/imatrix.py
def detach(self, module: torch.nn.Module) -> None:
    """Remove hooks and leave raw sum/count on module for second-pass pickup.

    Case 1 – accumulators present on module: leave them for next
    observer's ``attach()`` to pick up.

    Case 2 – no accumulators (second-pass cleanup): nothing to do.
    """
    if hasattr(module, "_imatrix_hook"):
        module._imatrix_hook.remove()
        del module._imatrix_hook
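The two-pass hand-off between attach() and detach() can be traced with a hypothetical stand-in class that copies their control flow but uses Python lists and a placeholder instead of torch tensors and a real forward-pre hook. The first observer creates accumulators on the module, detach() removes only the hook, and the next observer's attach() adopts the accumulators and deletes them from the module.

```python
from types import SimpleNamespace


class HandoffSketch:
    """Hypothetical stand-in mirroring IMatrixMSEObserver.attach/detach (no torch, no real hook)."""

    def __init__(self) -> None:
        self.imatrix_sum = None
        self.imatrix_count = 0

    def attach(self, module) -> None:
        if hasattr(module, "_imatrix_sum"):
            # Second pass: adopt accumulators left by a previous observer and clear them.
            self.imatrix_sum = module._imatrix_sum
            self.imatrix_count = module._imatrix_count
            del module._imatrix_sum
            del module._imatrix_count
            return
        if not hasattr(module, "in_features"):
            return
        # First pass: zeroed accumulators; the real code also registers a forward-pre hook.
        module._imatrix_sum = [0.0] * module.in_features
        module._imatrix_count = 0
        module._imatrix_hook = "placeholder-hook-handle"

    def detach(self, module) -> None:
        # Only the hook goes away; accumulators stay for the next observer's attach().
        if hasattr(module, "_imatrix_hook"):
            del module._imatrix_hook


layer = SimpleNamespace(in_features=2)
first = HandoffSketch()
first.attach(layer)
layer._imatrix_sum = [4.0, 1.0]  # pretend calibration data passed through the hook
layer._imatrix_count = 2
first.detach(layer)

second = HandoffSketch()
second.attach(layer)
print(second.imatrix_sum, second.imatrix_count)
```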

MemorylessMinMaxObserver

MemorylessMinMaxObserver(
    base_name: str,
    args: QuantizationArgs,
    **observer_kwargs,
)

Bases: Observer

Compute quantization parameters by taking the min/max of the observed value.

Source code in src/llmcompressor/observers/base.py
def __init__(
    self,
    base_name: str,
    args: QuantizationArgs,
    **observer_kwargs,
):
    super().__init__()
    self.base_name = base_name
    self.args = args

    self.args.observer_kwargs = self.args.observer_kwargs or {}
    self.args.observer_kwargs.update(observer_kwargs)

    self._fused_observers: set["Observer"] = set()

MinMaxObserver

MinMaxObserver(*args, **kwargs)

Bases: Observer

Compute quantization parameters by taking the moving average of min/max values.

Source code in src/llmcompressor/observers/min_max.py
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.avg_constant = self.args.observer_kwargs.get("averaging_constant", 0.01)
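A minimal sketch of a moving-average min/max update driven by averaging_constant, assuming the common exponential-moving-average form (each bound moves a fraction averaging_constant of the way toward the new batch's min/max) and assuming the first observation initializes the range directly. The exact update in min_max.py is not shown on this page, and the function name is hypothetical.

```python
from typing import Optional, Sequence, Tuple

Range = Tuple[float, float]


def ema_min_max(
    running: Optional[Range],
    observed: Sequence[float],
    averaging_constant: float = 0.01,
) -> Range:
    batch_min, batch_max = min(observed), max(observed)
    if running is None:
        # Assumption: the first observation sets the range without smoothing.
        return batch_min, batch_max
    run_min, run_max = running
    # Move each bound a fraction of the way toward the new batch statistic.
    return (
        run_min + averaging_constant * (batch_min - run_min),
        run_max + averaging_constant * (batch_max - run_max),
    )


state = ema_min_max(None, [-1.0, 2.0])
state = ema_min_max(state, [-3.0, 4.0], averaging_constant=0.5)
print(state)
```

With the default of 0.01, a single outlier batch shifts the range only slightly, which is the point of smoothing activation ranges over many calibration batches.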

MovingAverageMSEObserver

MovingAverageMSEObserver(*args, **kwargs)

Bases: Observer

Compute quantization parameters by finding the optimal min/max values which minimize the mean of quantization error squared, with moving average smoothing.

Source code in src/llmcompressor/observers/mse.py
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.avg_constant = self.args.observer_kwargs.get("averaging_constant", 0.01)
    observer_kwargs = self.args.observer_kwargs
    self.maxshrink = observer_kwargs.get("maxshrink", 0.20)
    self.patience = observer_kwargs.get("patience", 5)
    self.grid = observer_kwargs.get("grid", 100.0)
    self.norm = observer_kwargs.get("norm", 2.4)
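A sketch of the kind of shrink-and-search loop the maxshrink, patience, grid, and norm parameters suggest. It is an assumption, not the code in mse.py: candidate ranges are the observed min/max scaled by p = 1 - i/grid for i < maxshrink * grid, the error is the sum of |x - q(x)|^norm, and the search stops after patience consecutive candidates without improvement. The fake_quant helper (asymmetric uniform quantization) is hypothetical, and the moving-average smoothing across batches is omitted. IMatrixMSEObserver's description suggests the same idea with each input channel's error weighted by its E[x²].

```python
from typing import List, Sequence, Tuple


def fake_quant(values: Sequence[float], lo: float, hi: float, num_bits: int) -> List[float]:
    """Hypothetical asymmetric quantize-dequantize onto the integer grid [0, 2**num_bits - 1]."""
    lo, hi = min(lo, 0.0), max(hi, 0.0)  # the representable range must contain zero
    qmax = 2 ** num_bits - 1
    scale = (hi - lo) / qmax
    if scale == 0.0:
        return list(values)
    zero_point = round(-lo / scale)
    return [(min(max(round(v / scale) + zero_point, 0), qmax) - zero_point) * scale for v in values]


def quant_error(values: Sequence[float], lo: float, hi: float, num_bits: int, norm: float) -> float:
    """Sum of |x - q(x)|^norm over all values."""
    return sum(abs(v - d) ** norm for v, d in zip(values, fake_quant(values, lo, hi, num_bits)))


def mse_search(
    values: Sequence[float],
    num_bits: int = 4,
    maxshrink: float = 0.20,
    patience: int = 5,
    grid: float = 100.0,
    norm: float = 2.4,
) -> Tuple[float, float]:
    lo, hi = min(values), max(values)
    best_err, best_lo, best_hi = float("inf"), lo, hi
    stale = 0
    for i in range(int(maxshrink * grid)):
        p = 1 - i / grid  # shrink factors 1.00, 0.99, 0.98, ...
        cand_lo, cand_hi = p * lo, p * hi
        err = quant_error(values, cand_lo, cand_hi, num_bits, norm)
        if err < best_err:
            best_err, best_lo, best_hi, stale = err, cand_lo, cand_hi, 0
        else:
            stale += 1
            if stale >= patience:  # give up after `patience` non-improving candidates
                break
    return best_lo, best_hi


values = [-0.9, -0.4, -0.1, 0.0, 0.2, 0.5, 0.8, 3.0]
best_lo, best_hi = mse_search(values)
print(best_lo, best_hi)
```

Because the first candidate (p = 1) is the unshrunk range, the returned range can never have a larger error than plain min/max; shrinking only wins when clipping an outlier buys finer resolution for the remaining values.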

Observer

Observer(
    base_name: str,
    args: QuantizationArgs,
    **observer_kwargs,
)

Bases: InternalModule, RegistryMixin

Base class for observers which compute quantization parameters given observations of weights, activations, or attention states.

Parameters:

  • base_name (str) –

    str used to name the observer attribute

  • args (QuantizationArgs) –

    quantization args used to calibrate and quantize the observed value

  • **observer_kwargs

    keyword arguments for observer initialization

Methods:

  • attach

    Called when the observer is attached to a module.

  • detach

    Called before the observer is deleted from a module.

  • forward

    Update observer statistics from observed value.

  • fuse

    Link all observers in the list with each other for shared global_scale.

  • get_qparams

    Compute quantization parameters from accumulated statistics.

  • sync_activation_stats

    All-reduce accumulated activation statistics across DDP ranks.

  • update_statistics_from_observed

    Update internal observer statistics (min_vals, max_vals) from observed tensor.

Source code in src/llmcompressor/observers/base.py
def __init__(
    self,
    base_name: str,
    args: QuantizationArgs,
    **observer_kwargs,
):
    super().__init__()
    self.base_name = base_name
    self.args = args

    self.args.observer_kwargs = self.args.observer_kwargs or {}
    self.args.observer_kwargs.update(observer_kwargs)

    self._fused_observers: set["Observer"] = set()

attach

attach(module: Module) -> None

Called when the observer is attached to a module. Subclasses can override to register hooks or initialize state.

Parameters:

  • module (Module) –

    the module this observer is being attached to

Source code in src/llmcompressor/observers/base.py
def attach(self, module: torch.nn.Module) -> None:
    """
    Called when the observer is attached to a module.
    Subclasses can override to register hooks or initialize state.

    :param module: the module this observer is being attached to
    """
    pass

detach

detach(module: Module) -> None

Called before the observer is deleted from a module. Subclasses can override to remove hooks and clean up module attributes.

Parameters:

  • module (Module) –

    the module this observer is being removed from

Source code in src/llmcompressor/observers/base.py
def detach(self, module: torch.nn.Module) -> None:
    """
    Called before the observer is deleted from a module.
    Subclasses can override to remove hooks and clean up module attributes.

    :param module: the module this observer is being removed from
    """
    pass

forward

forward(observed: Tensor) -> Observer

Update observer statistics from observed value.

Parameters:

  • observed (Tensor) –

    value being observed

Returns:

  • Observer

    self for method chaining

Source code in src/llmcompressor/observers/base.py
@torch.no_grad
def forward(self, observed: torch.Tensor) -> "Observer":
    """
    Update observer statistics from observed value.

    :param observed: value being observed
    :return: self for method chaining
    """
    if observed.numel() == 0:
        return self

    observed = flatten_for_calibration(observed, self.base_name, self.args)
    self.update_statistics_from_observed(observed)
    return self

fuse staticmethod

fuse(observers: Iterable[Observer]) -> None

Link all observers in the list with each other for shared global_scale.

Parameters:

  • observers (Iterable[Observer]) –

    list of observers to fuse together

Source code in src/llmcompressor/observers/base.py
@staticmethod
def fuse(observers: Iterable["Observer"]) -> None:
    """
    Link all observers in the list with each other for shared global_scale.

    :param observers: list of observers to fuse together
    """
    observers = list(observers)
    for obs in observers:
        for other in observers:
            if other is not obs:
                obs._fused_observers.add(other)
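A usage sketch of the linking that fuse() performs, using a hypothetical stand-in class that carries only the _fused_observers set. After fusing three observers, each one references the other two but never itself; the list(observers) call is what allows a one-shot generator to be passed, since the nested loop iterates the sequence more than once.

```python
from typing import Iterable, Set


class FusableSketch:
    """Hypothetical stand-in holding only the fused-observer set."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._fused_observers: Set["FusableSketch"] = set()

    @staticmethod
    def fuse(observers: Iterable["FusableSketch"]) -> None:
        observers = list(observers)  # materialize so the nested loop can re-iterate it
        for obs in observers:
            for other in observers:
                if other is not obs:
                    obs._fused_observers.add(other)


q, k, v = FusableSketch("q"), FusableSketch("k"), FusableSketch("v")
FusableSketch.fuse(o for o in (q, k, v))  # a generator works because of list()
print(sorted(o.name for o in q._fused_observers))
```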

get_qparams

get_qparams() -> QParamsDict

Compute quantization parameters from accumulated statistics.

For TENSOR_GROUP, global_scale is computed from the absmax of this observer and all fused observers. Fused observers must already have statistics, so call observe_weight on all modules before calling get_qparams on any of them.

Returns:

  • QParamsDict

    dict with keys "scale", "zero_point", and "global_scale"

Source code in src/llmcompressor/observers/base.py
@torch.no_grad
def get_qparams(self) -> QParamsDict:
    """
    Compute quantization parameters from accumulated statistics.

    For TENSOR_GROUP, global_scale is computed from the absmax of
    this observer and all fused observers. Fused observers must
    already have statistics — call observe_weight on all modules
    before calling get_qparams on any of them.

    :return: dict with keys "scale", "zero_point", and "global_scale"
    """
    assert (
        self.has_statistics
    ), "No statistics available. Call observer(value) first."

    global_scale = None
    if self.args.strategy == QuantizationStrategy.TENSOR_GROUP:
        global_absmax = torch.max(-self.min_vals.min(), self.max_vals.max())
        for obs in self._fused_observers:
            assert (
                obs.has_statistics
            ), "All fused observers must be run before get_qparams."
            global_absmax = torch.max(global_absmax, -obs.min_vals.min())
            global_absmax = torch.max(global_absmax, obs.max_vals.max())
        global_scale = generate_gparam(
            -global_absmax.reshape(1), global_absmax.reshape(1)
        )

    scale, zero_point = calculate_qparams(
        min_vals=self.min_vals,
        max_vals=self.max_vals,
        quantization_args=self.args,
        global_scale=global_scale,
    )

    return {"scale": scale, "zero_point": zero_point, "global_scale": global_scale}
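A plain-Python sketch of the TENSOR_GROUP absmax step in get_qparams: the absolute maximum is taken over this observer's min_vals/max_vals and those of every fused observer, so every member of a fused group (for example Q/K/V) arrives at the same value and therefore the same global_scale. The function name is hypothetical, and the conversion from absmax to global_scale done by generate_gparam is not shown on this page, so it is left out.

```python
from typing import List, Sequence, Tuple

# (min_vals, max_vals) for one observer, one entry per group
Stats = Tuple[Sequence[float], Sequence[float]]


def global_absmax(own: Stats, fused: List[Stats]) -> float:
    min_vals, max_vals = own
    absmax = max(-min(min_vals), max(max_vals))
    for other_min, other_max in fused:
        # Every fused observer must already have statistics for this to be defined.
        absmax = max(absmax, -min(other_min), max(other_max))
    return absmax


q_stats: Stats = ([-1.0, -2.0], [3.0, 0.5])
k_stats: Stats = ([-4.0], [1.0])
v_stats: Stats = ([-0.5], [2.0])
print(global_absmax(q_stats, [k_stats, v_stats]))
```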

sync_activation_stats

sync_activation_stats() -> List[dist.Work]

All-reduce accumulated activation statistics across DDP ranks.

Note: weight statistics don't need to be synced since weights are synced across ranks; only data (activations) differs by rank.

Returns:

  • List[Work]

    list of async communication handles

Source code in src/llmcompressor/observers/base.py
def sync_activation_stats(self) -> List[dist.Work]:
    """All-reduce accumulated activation statistics across DDP ranks.

    Note: weight statistics don't need to be synced since weights
    are synced across ranks; only data (activations) differs by rank.

    :return: list of async communication handles
    """
    comms = []
    for attr_name, reduce_op in self._act_sync_dict.items():
        val = getattr(self, attr_name, None)
        if val is not None:
            comms.append(
                dist.all_reduce(as_broadcastable(val), op=reduce_op, async_op=True)
            )
    return comms
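A single-process simulation of what sync_activation_stats achieves: each attribute named in _act_sync_dict is reduced element-wise across ranks with its own reduce op, and attributes that are unset (None) are skipped. The contents of _act_sync_dict are not shown on this page, so the mapping below (MIN for min_vals, MAX for max_vals, SUM for an importance sum) is a hypothetical example, and plain lists stand in for tensors and torch.distributed.

```python
import operator
from typing import Callable, Dict, List, Optional

ReduceFn = Callable[[float, float], float]
REDUCE_OPS: Dict[str, ReduceFn] = {"SUM": operator.add, "MIN": min, "MAX": max}


def simulated_all_reduce(per_rank: List[List[float]], op: str) -> List[float]:
    """Element-wise reduction of one statistic across ranks; every rank would receive this result."""
    reduce_fn = REDUCE_OPS[op]
    result = list(per_rank[0])
    for values in per_rank[1:]:
        result = [reduce_fn(a, b) for a, b in zip(result, values)]
    return result


def sync_stats(
    ranks: List[Dict[str, Optional[List[float]]]], act_sync_dict: Dict[str, str]
) -> Dict[str, List[float]]:
    synced = {}
    for attr_name, op in act_sync_dict.items():
        values = [rank.get(attr_name) for rank in ranks]
        if any(v is None for v in values):
            continue  # mirrors skipping statistics that were never set
        synced[attr_name] = simulated_all_reduce(values, op)
    return synced


# Hypothetical mapping; the real _act_sync_dict is defined by each observer subclass.
act_sync_dict = {"min_vals": "MIN", "max_vals": "MAX", "_imatrix_sum": "SUM"}
ranks = [
    {"min_vals": [-1.0, -3.0], "max_vals": [2.0, 1.0], "_imatrix_sum": None},
    {"min_vals": [-2.0, -0.5], "max_vals": [1.5, 4.0], "_imatrix_sum": None},
]
synced = sync_stats(ranks, act_sync_dict)
print(synced)
```

In the real method the reductions are launched with async_op=True and the returned handles are waited on later, which lets several all-reduces overlap.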

update_statistics_from_observed abstractmethod

update_statistics_from_observed(observed: Tensor) -> None

Update internal observer statistics (min_vals, max_vals) from observed tensor.

Parameters:

  • observed (Tensor) –

    flattened observed value of shape (num_observations, *qparam_shape, group_size)

Source code in src/llmcompressor/observers/base.py
@abstractmethod
def update_statistics_from_observed(self, observed: torch.Tensor) -> None:
    """
    Update internal observer statistics (min_vals, max_vals) from observed tensor.

    :param observed: flattened observed value of shape
                    (num_observations, *qparam_shape, group_size)
    """
    raise NotImplementedError()

QParamsDict

Bases: TypedDict

Dictionary containing quantization parameters, with the keys "scale", "zero_point", and "global_scale" returned by get_qparams.

StaticMinMaxObserver

StaticMinMaxObserver(
    base_name: str,
    args: QuantizationArgs,
    **observer_kwargs,
)

Bases: MemorylessMinMaxObserver

Compute quantization parameters by taking the min/max of all observed values.

Source code in src/llmcompressor/observers/base.py
def __init__(
    self,
    base_name: str,
    args: QuantizationArgs,
    **observer_kwargs,
):
    super().__init__()
    self.base_name = base_name
    self.args = args

    self.args.observer_kwargs = self.args.observer_kwargs or {}
    self.args.observer_kwargs.update(observer_kwargs)

    self._fused_observers: set["Observer"] = set()
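The difference between MemorylessMinMaxObserver ("min/max of the observed value") and its subclass StaticMinMaxObserver ("min/max of all observed values") can be illustrated with two hypothetical stand-in classes over flat lists of floats. The real observers reduce tensors shaped (num_observations, *qparam_shape, group_size) per quantization group; this sketch keeps a single range for brevity.

```python
from typing import Optional, Sequence


class MemorylessMinMaxSketch:
    """Keeps only the range of the most recent observation (hypothetical stand-in)."""

    def __init__(self) -> None:
        self.min_val: Optional[float] = None
        self.max_val: Optional[float] = None

    def update(self, observed: Sequence[float]) -> None:
        self.min_val, self.max_val = min(observed), max(observed)


class StaticMinMaxSketch(MemorylessMinMaxSketch):
    """Keeps the range of everything observed so far (hypothetical stand-in)."""

    def update(self, observed: Sequence[float]) -> None:
        lo, hi = min(observed), max(observed)
        if self.min_val is None or self.max_val is None:
            self.min_val, self.max_val = lo, hi
        else:
            self.min_val = min(self.min_val, lo)
            self.max_val = max(self.max_val, hi)


memoryless, static = MemorylessMinMaxSketch(), StaticMinMaxSketch()
for batch in ([-1.0, 2.0], [-0.5, 5.0]):
    memoryless.update(batch)
    static.update(batch)
print((memoryless.min_val, memoryless.max_val), (static.min_val, static.max_val))
```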

flatten_for_calibration

flatten_for_calibration(
    value: Tensor, base_name: str, args: QuantizationArgs
) -> torch.Tensor

Reshapes the value according to the quantization strategy for the purposes of scale/zp calibration. The value after flattening has the following shape:

(num_observations, *qparam_shape, group_size)

For block quantization, value will be zero-padded if it is not evenly divisible by block_size, so as not to distort the calculated qparams and to be compatible with vLLM block-wise kernels that do not require even divisibility.

The first dim is the number of observations (usually the batch size times number of tokens), the middle dims are the dimension of the scales, and the last dim is the number of elements being quantized per group.

Parameters:

  • value (Tensor) –

    value being flattened

  • base_name (str) –

    one of weight, input, output, q, k, or v; used to characterize the value as a weight, activation, or attention state. Any other name raises ValueError.

  • args (QuantizationArgs) –

    quantization args for determining how the value is flattened

Returns:

  • Tensor

    value which has been reshaped for calibration

Source code in src/llmcompressor/observers/helpers.py
def flatten_for_calibration(
    value: torch.Tensor,
    base_name: str,
    args: QuantizationArgs,
) -> torch.Tensor:
    """
    Reshapes the value according to the quantization strategy for the purposes of
    scale/zp calibration. The value after flattening has the following shape:

    `(num_observations, *qparam_shape, group_size)`

    For block quantization, value will be zero-padded if it is not evenly
    divisible by block_size, so as not to distort the calculated qparams and to be
    compatible with vllm block-wise kernels that do not require even divisibility.

    The first dim is the number of observations (usually the batch size times number of
    tokens), the middle dims are the dimension of the scales, and the last dim is the
    number of elements being quantized per group.

    :param value: value being flattened
    :param base_name: weight, input, output, q/k/v. Used to characterize the value as
        being a weight, activation, or attention state
    :param args: quantization args for determining how the value is flattened
    :return: value which has been reshaped for calibration
    """
    if base_name == "weight":
        return _flatten_weight(value, args)
    elif base_name in ("input", "output"):
        return _flatten_activation(value, args)
    elif base_name in ("q", "k", "v"):
        return _flatten_attention(value, args)
    else:
        raise ValueError(f"Unknown quantization base name: {base_name}")
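A plain-Python sketch of the (num_observations, *qparam_shape, group_size) layout for a weight: each row is split into groups of group_size, giving a nested shape of (1, out_features, num_groups, group_size) with a single observation for the weight. It borrows the zero-padding idea the docstring describes for block quantization so the last group is always full. The real per-strategy layouts come from the private helpers _flatten_weight, _flatten_activation, and _flatten_attention, which are not shown here, so the function below is hypothetical.

```python
from typing import List

Group = List[float]


def flatten_rows_into_groups(weight: List[List[float]], group_size: int) -> List[List[List[Group]]]:
    """Return nested lists shaped (1, out_features, num_groups, group_size)."""
    rows = []
    for row in weight:
        pad = (-len(row)) % group_size  # zeros needed to fill the last group
        padded = row + [0.0] * pad
        groups = [padded[i:i + group_size] for i in range(0, len(padded), group_size)]
        rows.append(groups)
    return [rows]  # leading dim: a weight counts as a single observation


weight = [[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]
flat = flatten_rows_into_groups(weight, group_size=2)
print(len(flat), len(flat[0]), len(flat[0][0]), len(flat[0][0][0]))
```

A min/max reduction over the last dimension of this layout then yields one (min, max) pair per (row, group), which is exactly the qparam_shape the middle dimensions describe.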

fuse_weight_observers

fuse_weight_observers(model: Module)

Link weight observers across fused layer groups for shared global_scale.

For TENSOR_GROUP quantization (e.g. NVFP4), vLLM requires that fused layers (Q/K/V attention, gate/up MLP) share the same global_scale. This function links their observers so that get_qparams() computes global_scale from the combined statistics of all observers in the group.

Parameters:

  • model (Module) –

    model whose weight observers should be linked

Source code in src/llmcompressor/observers/helpers.py
def fuse_weight_observers(model: Module):
    """
    Link weight observers across fused layer groups for shared global_scale.

    For TENSOR_GROUP quantization (e.g. NVFP4), vLLM requires that fused
    layers (Q/K/V attention, gate/up MLP) share the same global_scale.
    This function links their observers so that get_qparams() computes
    global_scale from the combined statistics of all observers in the group.

    :param model: model whose weight observers should be linked
    """
    from llmcompressor.observers import Observer

    for submodule in model.modules():
        for layers_to_fuse in FUSED_LAYER_NAMES:
            if not all(hasattr(submodule, name) for name in layers_to_fuse):
                continue

            layers = [getattr(submodule, name) for name in layers_to_fuse]
            observers = []
            for layer in layers:
                obs = getattr(layer, "weight_observer", None)
                if obs is None:
                    break
                if obs.args.strategy != QuantizationStrategy.TENSOR_GROUP:
                    break
                observers.append(obs)
            else:
                Observer.fuse(observers)
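The for/else in fuse_weight_observers means a group is fused only when every layer in it exists and carries a TENSOR_GROUP weight observer; one missing or differently-quantized layer skips the whole group. A hypothetical sketch of that decision, returning the group names instead of calling Observer.fuse. The FUSED_LAYER_NAMES values (q_proj/k_proj/v_proj, gate_proj/up_proj) are assumed from the Q/K/V and gate/up description, and plain strings stand in for QuantizationStrategy members.

```python
from types import SimpleNamespace
from typing import List, Tuple

# Hypothetical groups; the real FUSED_LAYER_NAMES constant is not shown on this page.
FUSED_LAYER_NAMES: List[Tuple[str, ...]] = [("q_proj", "k_proj", "v_proj"), ("gate_proj", "up_proj")]


def groups_to_fuse(submodule) -> List[Tuple[str, ...]]:
    result = []
    for names in FUSED_LAYER_NAMES:
        if not all(hasattr(submodule, name) for name in names):
            continue  # this submodule does not contain the whole group
        for name in names:
            obs = getattr(getattr(submodule, name), "weight_observer", None)
            if obs is None or obs.strategy != "tensor_group":
                break  # one unqualified layer disables fusion for the group
        else:
            result.append(names)  # no break: every layer has a TENSOR_GROUP observer
    return result


def layer(strategy: str) -> SimpleNamespace:
    return SimpleNamespace(weight_observer=SimpleNamespace(strategy=strategy))


attn = SimpleNamespace(q_proj=layer("tensor_group"), k_proj=layer("tensor_group"), v_proj=layer("tensor_group"))
mlp = SimpleNamespace(gate_proj=layer("tensor_group"), up_proj=layer("channel"))
print(groups_to_fuse(attn), groups_to_fuse(mlp))
```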