Skip to content

llmcompressor.observers.imatrix

Classes:

IMatrixMSEObserver

IMatrixMSEObserver(*args, **kwargs)

Bases: Observer

MSE observer weighted by per-input-channel importance (E[x²]).

Supports the CHANNEL, GROUP, and TENSOR_GROUP quantization strategies for weight-only quantization of Linear modules. Falls back to uniform (unweighted) MSE when importance data is unavailable.

Importance is accumulated as a raw running sum of squared inputs per input channel (_imatrix_sum) plus a token count (_imatrix_count). These are synced across DDP ranks via _act_sync_dict before observation.

Methods:

  • attach

    Attach a forward-pre hook to accumulate E[x²] per input channel.

  • detach

    Remove the accumulation hook and leave the raw sum/count on the module for pickup in the second pass.

Source code in src/llmcompressor/observers/imatrix.py
def __init__(self, *args, **kwargs):
    """Read observer hyper-parameters and start with empty importance state.

    Options are taken from ``self.args.observer_kwargs`` (defaults shown):

    * ``maxshrink`` (0.95) -- must lie in ``[0, 1]``.
    * ``patience`` (5) -- must be ``>= 0``.
    * ``grid`` (20) -- must be ``> 0``.
    * ``norm`` (3.0) -- must be a finite, positive ``int``/``float``.
    * ``strict`` (False) -- stored as given, not validated here.

    :raises ValueError: if a validated option is out of range. Checks run
        in the order grid, patience, maxshrink, norm; the first failure wins.
    """
    super().__init__(*args, **kwargs)
    options = self.args.observer_kwargs

    self.maxshrink = options.get("maxshrink", 0.95)
    self.patience = options.get("patience", 5)
    self.grid = options.get("grid", 20)
    self.norm = options.get("norm", 3.0)
    self.strict = options.get("strict", False)

    # Raw E[x²] accumulators: an empty sum and a zero int64 token count.
    # attach() either fills them on the module or copies them from it.
    self._imatrix_sum: Optional[torch.Tensor] = None
    self._imatrix_count: torch.Tensor = torch.tensor(0, dtype=torch.int64)

    if self.grid <= 0:
        raise ValueError(f"grid must be > 0, got {self.grid}")
    if self.patience < 0:
        raise ValueError(f"patience must be >= 0, got {self.patience}")
    if not (0 <= self.maxshrink <= 1):
        raise ValueError(f"maxshrink must be in [0, 1], got {self.maxshrink}")

    # Short-circuit order matters: only real numbers reach isfinite(), and
    # only finite ones reach the sign test.
    norm_is_valid = (
        isinstance(self.norm, (int, float))
        and math.isfinite(self.norm)
        and self.norm > 0
    )
    if not norm_is_valid:
        raise ValueError(f"norm must be a finite positive number, got {self.norm}")

attach

attach(module: Module) -> None

Attach a forward-pre hook to accumulate E[x²] per input channel.

If raw accumulators (_imatrix_sum / _imatrix_count) already exist on the module (second pass after IMatrixGatherer), copy them to the observer and skip hook registration.

Source code in src/llmcompressor/observers/imatrix.py
def attach(self, module: torch.nn.Module) -> None:
    """Attach a forward-pre hook to accumulate E[x²] per input channel.

    If raw accumulators (``_imatrix_sum`` / ``_imatrix_count``) already
    exist on the module (second pass after IMatrixGatherer), copy them
    to the observer, delete them from the module, and skip hook
    registration.

    Otherwise, if the module exposes ``in_features``, zero-initialised
    accumulators are placed on the module. A forward-pre hook is then
    registered and stored as ``module._imatrix_hook`` so that ``detach()``
    can remove it. On every forward call the hook does two things:

    * adds the per-channel sum of squared inputs, taken over all leading
      (token) dimensions, to ``module._imatrix_sum``;
    * adds the number of tokens to ``module._imatrix_count``.

    Modules without ``in_features`` are left untouched.

    :param module: module to observe. Presumably a ``torch.nn.Linear`` (any
        module exposing ``in_features``); verify against callers.
    """
    if hasattr(module, "_imatrix_sum"):
        # Second pass: take ownership of the statistics gathered earlier.
        self._imatrix_sum = module._imatrix_sum
        self._imatrix_count = module._imatrix_count
        del module._imatrix_sum
        del module._imatrix_count
        return

    if not hasattr(module, "in_features"):
        return

    in_features = module.in_features
    module._imatrix_sum = torch.zeros(in_features, dtype=IMATRIX_PRECISION)
    module._imatrix_count = torch.tensor(0, dtype=torch.int64)

    def _hook(mod, args):
        # Forward-pre hooks receive positional inputs only. A module called
        # with keyword-only inputs yields an empty tuple, so there is nothing
        # to record (indexing it would raise IndexError inside forward()).
        if isinstance(args, tuple):
            if not args:
                return
            x = args[0]
        else:
            x = args
        if isinstance(x, tuple):
            x = x[0] if x else None
        # A 0-dim tensor has no channel axis to attribute importance to.
        if not isinstance(x, torch.Tensor) or x.dim() == 0:
            return

        x_f = x.detach().to(IMATRIX_PRECISION)
        if x_f.dim() == 1:
            # Unbatched input of shape (in_features,) is a single token.
            # Without this, the token-dim list below would be empty, and
            # torch.sum treats an empty dim list as a reduction over all
            # dims, which would broadcast one scalar into every channel.
            x_f = x_f.unsqueeze(0)

        device = x_f.device
        n_tokens = math.prod(x_f.shape[:-1])
        token_sum = x_f.pow(2).sum(dim=list(range(x_f.dim() - 1)))

        # Keep the accumulators on the same device as the incoming activations.
        mod._imatrix_sum = mod._imatrix_sum.to(device)
        mod._imatrix_count = mod._imatrix_count.to(device)

        mod._imatrix_sum.add_(token_sum)
        mod._imatrix_count += n_tokens

    module._imatrix_hook = module.register_forward_pre_hook(_hook)

detach

detach(module: Module) -> None

Remove the accumulation hook and leave the raw sum/count on the module for pickup in the second pass.

Case 1 – accumulators present on the module: leave them in place for the next observer's attach() to pick up.

Case 2 – no accumulators (second-pass cleanup): nothing to do.

Source code in src/llmcompressor/observers/imatrix.py
def detach(self, module: torch.nn.Module) -> None:
    """Unregister the E[x²] hook that ``attach()`` stored on ``module``, if any.

    The accumulators ``_imatrix_sum`` / ``_imatrix_count`` are deliberately
    left alone. When present (first pass), they remain on the module for the
    next observer's ``attach()`` to pick up. When absent (second-pass
    cleanup), there is nothing else to do.
    """
    try:
        handle = module._imatrix_hook
    except AttributeError:
        # No hook was ever registered on this module.
        return
    handle.remove()
    del module._imatrix_hook