Skip to content

llmcompressor.observers.base

Classes:

  • Observer

    Base class for observers which compute quantization parameters given observations of weights, activations, or attention states.

  • QParamsDict

    Dictionary containing quantization parameters.

Observer

Observer(
    base_name: str,
    args: QuantizationArgs,
    **observer_kwargs,
)

Bases: InternalModule, RegistryMixin

Base class for observers which compute quantization parameters given observations of weights, activations, or attention states.

Parameters:

  • base_name (str) –

    str used to name the observer attribute

  • args (QuantizationArgs) –

    quantization args used to calibrate and quantize the observed value

  • **observer_kwargs

    keyword arguments for observer initialization

Methods:

  • attach

    Called when the observer is attached to a module.

  • detach

    Called before the observer is deleted from a module.

  • forward

    Update observer statistics from observed value.

  • fuse

    Link all observers in the list with each other for shared global_scale.

  • get_qparams

    Compute quantization parameters from accumulated statistics.

  • sync_activation_stats

    All-reduce accumulated activation statistics across DDP ranks.

  • update_statistics_from_observed

    Update internal observer statistics (min_vals, max_vals) from observed tensor.

Source code in src/llmcompressor/observers/base.py
def __init__(
    self,
    base_name: str,
    args: QuantizationArgs,
    **observer_kwargs,
):
    """
    Create an observer bound to a named attribute and its quantization args.

    :param base_name: str used to name the observer attribute
    :param args: quantization args used to calibrate and quantize the
        observed value
    :param observer_kwargs: keyword arguments for observer initialization;
        they are merged into ``args.observer_kwargs``
    """
    super().__init__()
    self.base_name = base_name
    self.args = args

    # Ensure a dict is stored on ``args`` (a falsy value becomes {}), then
    # merge the extra kwargs into it. The merge modifies the ``args``
    # object that was passed in, not a copy.
    existing_kwargs = self.args.observer_kwargs or {}
    self.args.observer_kwargs = existing_kwargs
    self.args.observer_kwargs.update(observer_kwargs)

    # peers whose statistics share a global_scale (populated by ``fuse``)
    self._fused_observers: set["Observer"] = set()

attach

attach(module: Module) -> None

Called when the observer is attached to a module. Subclasses can override to register hooks or initialize state.

Parameters:

  • module (Module) –

    the module this observer is being attached to

Source code in src/llmcompressor/observers/base.py
def attach(self, module: torch.nn.Module) -> None:
    """
    Hook invoked when this observer is attached to ``module``.

    The base implementation intentionally does nothing; subclasses may
    override it to register hooks or initialize state on the module.

    :param module: the module this observer is being attached to
    """

detach

detach(module: Module) -> None

Called before the observer is deleted from a module. Subclasses can override to remove hooks and clean up module attributes.

Parameters:

  • module (Module) –

    the module this observer is being removed from

Source code in src/llmcompressor/observers/base.py
def detach(self, module: torch.nn.Module) -> None:
    """
    Hook invoked right before this observer is removed from ``module``.

    The base implementation intentionally does nothing; subclasses may
    override it to remove hooks and clean up module attributes.

    :param module: the module this observer is being removed from
    """

forward

forward(observed: Tensor) -> Observer

Update observer statistics from observed value.

Parameters:

  • observed (Tensor) –

    value being observed

Returns:

  • Observer

    self for method chaining

Source code in src/llmcompressor/observers/base.py
@torch.no_grad
def forward(self, observed: torch.Tensor) -> "Observer":
    """
    Update observer statistics from observed value.

    :param observed: value being observed
    :return: self for method chaining
    """
    if observed.numel() == 0:
        return self

    observed = flatten_for_calibration(observed, self.base_name, self.args)
    self.update_statistics_from_observed(observed)
    return self

fuse staticmethod

fuse(observers: Iterable[Observer]) -> None

Link all observers in the list with each other for shared global_scale.

Parameters:

  • observers (Iterable[Observer]) –

    list of observers to fuse together

Source code in src/llmcompressor/observers/base.py
@staticmethod
def fuse(observers: Iterable["Observer"]) -> None:
    """
    Mutually link a group of observers so they share a global_scale.

    Every observer in the group gains every *other* member in its
    ``_fused_observers`` set; existing links are kept.

    :param observers: observers to fuse together (consumed once)
    """
    group = list(observers)  # materialize so a one-shot iterator works
    for member in group:
        member._fused_observers.update(
            peer for peer in group if peer is not member
        )

get_qparams

get_qparams() -> QParamsDict

Compute quantization parameters from accumulated statistics.

For TENSOR_GROUP, global_scale is computed from the absmax of this observer and all fused observers. Fused observers must already have statistics — call observe_weight on all modules before calling get_qparams on any of them.

Returns:

  • QParamsDict

    dict with keys "scale", "zero_point", and "global_scale"

Source code in src/llmcompressor/observers/base.py
@torch.no_grad
def get_qparams(self) -> QParamsDict:
    """
    Compute quantization parameters from accumulated statistics.

    For TENSOR_GROUP, global_scale is computed from the absmax of
    this observer and all fused observers. Fused observers must
    already have statistics — call observe_weight on all modules
    before calling get_qparams on any of them.

    :return: dict with keys "scale", "zero_point", and "global_scale";
        "global_scale" is None unless the strategy is TENSOR_GROUP
    :raises AssertionError: if this observer has no statistics, or (for
        TENSOR_GROUP) if any fused observer has no statistics
    """
    # Explicit raises instead of `assert` statements: asserts are stripped
    # under `python -O`, which would silently skip these precondition
    # checks. AssertionError and the messages are kept so existing callers
    # see the same exception.
    if not self.has_statistics:
        raise AssertionError(
            "No statistics available. Call observer(value) first."
        )

    global_scale = None
    if self.args.strategy == QuantizationStrategy.TENSOR_GROUP:
        # Largest magnitude seen: max(-min, max) equals max(|min|, |max|)
        # (assumes min_vals <= max_vals elementwise — TODO confirm).
        global_absmax = torch.max(-self.min_vals.min(), self.max_vals.max())
        for obs in self._fused_observers:
            if not obs.has_statistics:
                raise AssertionError(
                    "All fused observers must be run before get_qparams."
                )
            global_absmax = torch.max(global_absmax, -obs.min_vals.min())
            global_absmax = torch.max(global_absmax, obs.max_vals.max())
        # one symmetric range [-absmax, absmax] shared across the fused group
        global_scale = generate_gparam(
            -global_absmax.reshape(1), global_absmax.reshape(1)
        )

    scale, zero_point = calculate_qparams(
        min_vals=self.min_vals,
        max_vals=self.max_vals,
        quantization_args=self.args,
        global_scale=global_scale,
    )

    return {"scale": scale, "zero_point": zero_point, "global_scale": global_scale}

sync_activation_stats

sync_activation_stats() -> List[dist.Work]

All-reduce accumulated activation statistics across DDP ranks.

Note: weight statistics don't need to be synced since weights are synced across ranks; only data (activations) differs by rank.

Returns:

  • List[Work]

    list of async communication handles

Source code in src/llmcompressor/observers/base.py
def sync_activation_stats(self) -> List[dist.Work]:
    """
    All-reduce accumulated activation statistics across DDP ranks.

    Note: weight statistics don't need to be synced since weights are
    synced across ranks; only data (activations) differs by rank.

    For each ``(attribute name, reduce op)`` pair in ``self._act_sync_dict``,
    the attribute is reduced asynchronously with that op; attributes that
    are missing or None are skipped.

    :return: list of async communication handles
    """
    handles = []
    for name, op in self._act_sync_dict.items():
        stat = getattr(self, name, None)
        if stat is None:
            continue
        handles.append(
            dist.all_reduce(as_broadcastable(stat), op=op, async_op=True)
        )
    return handles

update_statistics_from_observed abstractmethod

update_statistics_from_observed(observed: Tensor) -> None

Update internal observer statistics (min_vals, max_vals) from observed tensor.

Parameters:

  • observed (Tensor) –

    flattened observed value of shape (num_observations, *qparam_shape, group_size)

Source code in src/llmcompressor/observers/base.py
@abstractmethod
def update_statistics_from_observed(self, observed: torch.Tensor) -> None:
    """
    Fold a flattened observation into this observer's running statistics.

    Concrete observers must implement this to update internal statistics
    (min_vals, max_vals); the base version only raises.

    :param observed: flattened observed value of shape
                    (num_observations, *qparam_shape, group_size)
    """
    raise NotImplementedError

QParamsDict

Bases: TypedDict

Dictionary containing quantization parameters, with keys "scale", "zero_point", and "global_scale" (as returned by Observer.get_qparams).