Bases: Observer
MSE observer weighted by per-input-channel importance (E[x²]).
Supports CHANNEL, GROUP, and TENSOR_GROUP for weight-only Linear modules.
Falls back to uniform MSE when importance data is unavailable.
Importance is accumulated as raw _imatrix_sum / _imatrix_count
and synced across DDP ranks via _act_sync_dict before observation.
Methods:
- attach – Attach a forward-pre hook to accumulate E[x²] per input channel.
- detach – Remove hooks and leave raw sum/count on module for second-pass pickup.
Source code in src/llmcompressor/observers/imatrix.py
def __init__(self, *args, **kwargs):
    """Configure the observer from ``self.args.observer_kwargs``.

    Recognised keys and the fallback used when a key is absent:
    ``maxshrink`` (0.95), ``patience`` (5), ``grid`` (20), ``norm`` (3.0)
    and ``strict`` (False). The importance accumulators start out empty:
    ``_imatrix_sum`` is ``None`` and ``_imatrix_count`` is an int64 zero.

    Raises:
        ValueError: if ``grid`` is not strictly positive, ``patience`` is
            negative, ``maxshrink`` lies outside ``[0, 1]``, or ``norm``
            is not a finite, strictly positive int/float.
    """
    super().__init__(*args, **kwargs)

    options = self.args.observer_kwargs
    fallbacks = (
        ("maxshrink", 0.95),
        ("patience", 5),
        ("grid", 20),
        ("norm", 3.0),
        ("strict", False),
    )
    for key, fallback in fallbacks:
        setattr(self, key, options.get(key, fallback))

    # Raw accumulators; populated later, outside of construction.
    self._imatrix_sum: Optional[torch.Tensor] = None
    self._imatrix_count: torch.Tensor = torch.tensor(0, dtype=torch.int64)

    # Validation runs only after every attribute has been assigned.
    if self.grid <= 0:
        raise ValueError(f"grid must be > 0, got {self.grid}")

    if self.patience < 0:
        raise ValueError(f"patience must be >= 0, got {self.patience}")

    shrink_in_range = 0 <= self.maxshrink <= 1
    if not shrink_in_range:
        raise ValueError(f"maxshrink must be in [0, 1], got {self.maxshrink}")

    # Short-circuit order matters: the type check guards math.isfinite,
    # and the finiteness check guards the sign comparison.
    norm_ok = (
        isinstance(self.norm, (int, float))
        and math.isfinite(self.norm)
        and self.norm > 0
    )
    if not norm_ok:
        raise ValueError(f"norm must be a finite positive number, got {self.norm}")
|
attach
attach(module: Module) -> None
Attach a forward-pre hook to accumulate E[x²] per input channel.
If raw accumulators (_imatrix_sum / _imatrix_count) already
exist on the module (second pass after IMatrixGatherer), copy them
to the observer and skip hook registration.
Source code in src/llmcompressor/observers/imatrix.py
def attach(self, module: torch.nn.Module) -> None:
    """Attach a forward-pre hook to accumulate E[x²] per input channel.

    If raw accumulators (``_imatrix_sum`` / ``_imatrix_count``) already
    exist on the module (second pass after IMatrixGatherer), move them
    to the observer and skip hook registration.

    Modules without an ``in_features`` attribute are left untouched.
    Otherwise zero-initialised accumulators are stored on the module and
    the hook handle is kept as ``module._imatrix_hook``.

    :param module: module whose first positional input is observed
    """
    if hasattr(module, "_imatrix_sum"):
        # Second pass: take ownership of the accumulators and remove them
        # from the module.
        self._imatrix_sum = module._imatrix_sum
        self._imatrix_count = module._imatrix_count
        del module._imatrix_sum
        del module._imatrix_count
        return

    if not hasattr(module, "in_features"):
        return

    module._imatrix_sum = torch.zeros(module.in_features, dtype=IMATRIX_PRECISION)
    module._imatrix_count = torch.tensor(0, dtype=torch.int64)

    def _hook(mod, args):
        # A keyword-only call hands the pre-hook an empty positional
        # tuple; there is nothing to observe in that case.
        if isinstance(args, tuple):
            if not args:
                return
            x = args[0]
        else:
            x = args
        if isinstance(x, tuple):
            if not x:
                return
            x = x[0]
        # Also covers ``None``. 0-dim tensors have no channel axis.
        if not isinstance(x, torch.Tensor) or x.dim() == 0:
            return

        x_f = x.detach().to(IMATRIX_PRECISION)
        n_tokens = math.prod(x_f.shape[:-1])
        # Collapse all leading axes into one token axis and reduce over
        # it. This avoids handing torch.sum an empty ``dim`` list for 1-D
        # inputs, which may be treated as a reduction over every axis and
        # would then broadcast one scalar into all channels.
        token_sum = x_f.reshape(n_tokens, x_f.shape[-1]).pow(2).sum(dim=0)

        # Keep the accumulators on the same device as the activations.
        device = x_f.device
        mod._imatrix_sum = mod._imatrix_sum.to(device)
        mod._imatrix_count = mod._imatrix_count.to(device)
        mod._imatrix_sum.add_(token_sum)
        mod._imatrix_count += n_tokens

    module._imatrix_hook = module.register_forward_pre_hook(_hook)
|
detach
detach(module: Module) -> None
Remove hooks and leave raw sum/count on module for second-pass pickup.
Case 1 – accumulators present on module: leave them for next
observer's attach() to pick up.
Case 2 – no accumulators (second-pass cleanup): nothing to do.
Source code in src/llmcompressor/observers/imatrix.py
def detach(self, module: torch.nn.Module) -> None:
    """Unregister the accumulation hook, if one is stored on the module.

    Only ``module._imatrix_hook`` is touched: the handle is removed and
    the attribute deleted. Any ``_imatrix_sum`` / ``_imatrix_count``
    present on the module are left in place so that a later ``attach()``
    can pick them up. When no hook attribute exists this is a no-op.
    """
    try:
        handle = module._imatrix_hook
    except AttributeError:
        # No hook was registered on this module; nothing to clean up.
        return

    handle.remove()
    del module._imatrix_hook
|