
llmcompressor.modifiers.quantization.quantization.mixin

Classes:

  • QuantizationMixin

    Mixin which enables a Modifier to act as a quantization config, attaching observers, calibration hooks, and compression wrappers to modifiers.

QuantizationMixin

Bases: HooksMixin

Mixin which enables a Modifier to act as a quantization config, attaching observers, calibration hooks, and compression wrappers to modifiers.

Lifecycle:

  • on_initialize: QuantizationMixin.initialize_quantization
    • Attach schemes to modules
    • Disable quantization until calibration starts
  • on_start: QuantizationMixin.start_calibration
    • Attach observers to modules
    • Attach calibration hooks
    • Apply calibration status
    • Enable quantization during calibration
  • on_end: QuantizationMixin.end_calibration
    • Remove calibration hooks
    • Apply freeze status
    • Keep quantization enabled for future steps
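The lifecycle above can be pictured as a small state machine. The following is a minimal, self-contained sketch under stated assumptions: ToyModule and ToyQuantizationLifecycle are hypothetical stand-ins (plain Python objects rather than torch modules), and status strings such as "calibration" and "frozen" are illustrative labels, not the QuantizationStatus values used by the library. It only mirrors the order of operations shown in the source below: schemes are attached first, observers and hooks are attached when calibration starts, and both are removed at the end while quantization stays enabled.

```python
class ToyModule:
    """Hypothetical stand-in for a torch.nn.Module carrying quantization state."""

    def __init__(self, name):
        self.name = name
        self.scheme = None
        self.status = None
        self.observer = None
        self.quantization_enabled = True


class ToyQuantizationLifecycle:
    """Illustrative sketch of the three lifecycle steps; not llmcompressor code."""

    def __init__(self, targets):
        self.targets = set(targets)
        self.calibration_hooks = set()

    def _matched(self, modules):
        # Stand-in for match_named_modules(model, resolved_targets, ignore)
        return [m for m in modules if m.name in self.targets]

    def initialize_quantization(self, modules):
        for m in self._matched(modules):
            m.scheme = "W8A8"  # attach a scheme (preset name chosen arbitrarily)
            m.status = "initialized"
        for m in modules:
            m.quantization_enabled = False  # disabled until calibration starts

    def start_calibration(self, modules):
        for m in self._matched(modules):
            m.observer = "minmax"  # attach an observer
            self.calibration_hooks.add(m.name)  # register a calibration hook
            m.status = "calibration"
            m.quantization_enabled = True

    def end_calibration(self, modules):
        self.calibration_hooks.clear()  # remove calibration hooks
        for m in self._matched(modules):
            m.observer = None  # freezing removes observers
            m.status = "frozen"
        for m in modules:
            m.quantization_enabled = True  # keep quantization enabled afterwards


modules = [ToyModule("q_proj"), ToyModule("lm_head")]
lifecycle = ToyQuantizationLifecycle(targets=["q_proj"])
lifecycle.initialize_quantization(modules)
lifecycle.start_calibration(modules)
# ... calibration forward passes would run here ...
lifecycle.end_calibration(modules)
```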

NOTE: QuantizationMixin does not update scales and zero-points on its own, as this is not desired for all Modifiers inheriting from it. A Modifier must explicitly call observe(modules, base_name="weight") and then update_qparams(modules, base_name="weight"). See the QuantizationModifier.on_event method for an example.
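To illustrate the split described in the NOTE, here is a hedged sketch using a hypothetical min-max observer over plain Python floats. ToyMinMaxObserver and toy_update_qparams are illustrative names, not llmcompressor APIs, and the asymmetric signed 8-bit formula is a common textbook choice rather than a description of the library's observers. The point is the two-step contract: observing only accumulates statistics, and the scale and zero-point change only when the update step is called explicitly.

```python
class ToyMinMaxObserver:
    """Hypothetical observer: tracks a running min/max of all values it sees."""

    def __init__(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, values):
        # Step 1: accumulate statistics only; no qparams are computed here.
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))


def toy_update_qparams(observer, qmin=-128, qmax=127):
    """Step 2: turn observed statistics into an asymmetric (scale, zero_point)."""
    # Widen the range to include 0.0 so that zero is exactly representable.
    lo = min(observer.min_val, 0.0)
    hi = max(observer.max_val, 0.0)
    scale = (hi - lo) / (qmax - qmin) if hi > lo else 1.0
    zero_point = round(qmin - lo / scale)
    zero_point = max(qmin, min(qmax, zero_point))  # clamp into the integer range
    return scale, zero_point


obs = ToyMinMaxObserver()
obs.observe([-1.0, 0.5])  # e.g. a first batch of values
obs.observe([2.0])        # statistics keep accumulating across calls
scale, zp = toy_update_qparams(obs)  # qparams change only on this explicit call
```

Keeping the steps separate lets the caller decide when qparams are recomputed, for example once after all weights have been observed.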

Parameters:

  • config_groups

    dictionary specifying quantization schemes to apply to target modules. Modules not matching a scheme target will NOT be quantized.

  • targets

    list of layer names to quantize if a scheme is provided. If unset, it will contain all targets listed in config_groups. If config_groups is also unset, it defaults to ["Linear"] (i.e. all Linear layers are targeted). This field is not the source of truth for finding all matching target layers in a model, because additional targets can be stored in config_groups; use self.resolved_targets instead.

  • ignore

    optional list of module class names or submodule names to not quantize even if they match a target in config_groups. Defaults to empty list.

  • scheme

    a single quantization scheme to apply to the model. This is a dictionary that supports all keys from QuantizationScheme except targets, which is set from the modifier-level targets parameter. It can also be a preset scheme name, or a dictionary mapping preset scheme names to targets, for example {"W8A8": ["Linear"]} for 8-bit weights and activations.

  • kv_cache_scheme

    optional QuantizationArgs specifying how the kv cache is quantized. If None, the kv cache is not quantized. When kv cache quantization is applied to a transformers AutoModelForCausalLM, kv_cache_scheme is converted into a QuantizationScheme that:
      • targets the k_proj and v_proj modules of the model, whose outputs are the keys and values that may be cached
      • quantizes the outputs of those layers, so that keys and values are compressed before being stored in the cache

    This assumes the model contains modules with k_proj and v_proj in their names. If it does not and kv_cache_scheme is not None, kv cache quantization will fail.

  • weight_observer

    optional observer name for weight quantization. Overrides the default observer specified in the scheme. Valid values include "minmax", "mse", "static_minmax", "memoryless_minmax", "memoryless_mse".

  • input_observer

    optional observer name for input activation quantization. Overrides the default observer specified in the scheme. Valid values include "minmax", "mse", "static_minmax", "memoryless_minmax", "memoryless_mse".

  • output_observer

    optional observer name for output activation quantization. Overrides the default observer specified in the scheme. Valid values include "minmax", "mse", "static_minmax", "memoryless_minmax", "memoryless_mse".

  • observer

    optional dictionary to specify observers for multiple quantization types at once. Keys can be "weights", "input", or "output"; values are observer names. Example: {"weights": "mse", "input": "mse"}. If both the individual observer parameters (weight_observer, input_observer, output_observer) and the observer dict are provided, the observer dict takes precedence.

  • bypass_divisibility_checks

    if True, skip the check that the number of weight columns is divisible by group_size for the GROUP and TENSOR_GROUP quantization strategies. Use this when your runtime (e.g. vLLM) supports non-divisible dimensions. Defaults to False.
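The kv_cache_scheme behavior described above depends on module names. The sketch below shows the kind of name-based selection it implies, under assumptions: toy_kv_cache_targets is a hypothetical helper, matching on the final dotted component of a module name is a simplification that may differ from the library's actual target matching, and a plain dict stands in for QuantizationArgs.

```python
def toy_kv_cache_targets(module_names, kv_cache_scheme):
    """Pick modules whose outputs are cached keys/values (illustrative only)."""
    if kv_cache_scheme is None:
        return []  # kv cache is left unquantized
    # Keep modules whose last name component is k_proj or v_proj.
    matches = [
        name for name in module_names
        if name.rsplit(".", 1)[-1] in ("k_proj", "v_proj")
    ]
    if not matches:
        # Mirrors the documented failure when the naming assumption does not hold.
        raise ValueError("kv_cache_scheme is set but no k_proj/v_proj modules exist")
    return matches


names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.0.mlp.up_proj",
]
kv_targets = toy_kv_cache_targets(names, {"num_bits": 8})
```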

Methods:

  • end_calibration

    Remove calibration hooks and observers, and set the model status to frozen.

  • has_config

    Determine if the user has specified a quantization config on this modifier

  • initialize_quantization

    Attach quantization schemes to modules in the model according to the quantization config specified on this modifier.

  • resolve_quantization_config

    Returns the quantization config specified by this modifier

  • start_calibration

    Attach observers, register activation calibration hooks (including kv_cache quantization) and enable quantization as we calibrate.

  • sync_obs_act_stats

    Synchronize the activation statistics for observers across DDP ranks.

  • validate_observer

    Validate observer dictionary format. Accepts keys: 'weights', 'input', 'output'

Attributes:

  • resolved_config (QuantizationConfig) –

    Quantization config needs to be resolved just once based on scheme and config_groups inputs.

  • resolved_targets (set[str]) –

    Set of all resolved targets, i.e. all unique targets listed in the resolved quantization config.

resolved_config property

resolved_config: QuantizationConfig

Quantization config needs to be resolved just once based on scheme and config_groups inputs.

resolved_targets property

resolved_targets: set[str]

Set of all resolved targets, i.e. all unique targets listed in the resolved quantization config. Use this property instead of the targets field, as targets can also come from config_groups depending on how the recipe is configured.
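A minimal sketch of why resolved_targets can differ from the targets field: the modifier-level targets may keep its default of ["Linear"] while config groups add other targets. The helper toy_resolved_targets is hypothetical, and config groups are modeled as plain dicts instead of QuantizationScheme objects.

```python
def toy_resolved_targets(config_groups):
    """Collect every unique target across all config groups (illustrative)."""
    resolved = set()
    for group in config_groups.values():
        resolved.update(group["targets"])
    return resolved


# The modifier-level `targets` field may still hold its default value...
targets_field = ["Linear"]
# ...while config_groups (hypothetical contents) target more module types.
config_groups = {
    "group_0": {"targets": ["Linear"]},
    "group_1": {"targets": ["Embedding"]},
}
resolved = toy_resolved_targets(config_groups)
```

Deriving the set from the config groups, rather than trusting the targets field, keeps module matching correct regardless of which of the two inputs the recipe used.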

end_calibration

end_calibration(model: Module)

Remove calibration hooks and observers, and set the model status to frozen. Keep quantization enabled for future operations.

Parameters:

  • model (Module) –

    model to end calibration for

Source code in src/llmcompressor/modifiers/quantization/quantization/mixin.py
def end_calibration(self, model: torch.nn.Module):
    """
    Remove calibration hooks and observers, and set the model status to frozen.
    Keep quantization enabled for future operations

    :param model: model to end calibration for
    """
    self.remove_hooks(self._calibration_hooks)
    for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
        freeze_module_quantization(module)  # remove observers

    model.apply(enable_quantization)  # keep quantization enabled

has_config

has_config() -> bool

Determine if the user has specified a quantization config on this modifier

Source code in src/llmcompressor/modifiers/quantization/quantization/mixin.py
def has_config(self) -> bool:
    """
    Determine if the user has specified a quantization config on this modifier
    """
    return not (
        self.config_groups is None
        and self.targets == ["Linear"]
        and self.ignore == []
        and self.scheme is None
        and self.kv_cache_scheme is None
    )

initialize_quantization

initialize_quantization(model: Module)

Attach quantization schemes to modules in the model according to the quantization config specified on this modifier

Parameters:

  • model (Module) –

    model to attach quantization schemes to

Source code in src/llmcompressor/modifiers/quantization/quantization/mixin.py
def initialize_quantization(self, model: torch.nn.Module):
    """
    Attach quantization schemes to modules in the model according to
    the quantization config specified on this modifier

    :param model: model to attach schemes and observers to
    """

    for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
        reset_quantization_status(module)  # reset any previously applied qconfigs

    apply_quantization_config(model, self.resolved_config)

    if not self.bypass_divisibility_checks:
        validate_group_size_divisibility(model, self.resolved_targets, self.ignore)

    # disable quantization until calibration
    model.apply(disable_quantization)
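initialize_quantization runs validate_group_size_divisibility unless bypass_divisibility_checks is set. The following sketch shows the arithmetic such a check performs. toy_validate_group_size is a hypothetical helper working on (rows, columns) shape tuples, not the library function, and it assumes the torch.nn.Linear layout where the weight has shape (out_features, in_features), so the grouped dimension is in_features.

```python
def toy_validate_group_size(weight_shapes, group_size, bypass_divisibility_checks=False):
    """Raise if any weight's column count is not a multiple of group_size (sketch)."""
    if bypass_divisibility_checks:
        return None  # trust the runtime to handle a ragged final group
    bad = [
        name
        for name, (_rows, cols) in weight_shapes.items()
        if cols % group_size != 0
    ]
    if bad:
        raise ValueError(f"columns not divisible by group_size={group_size}: {bad}")
    return None


shapes = {"q_proj": (4096, 4096), "odd_proj": (4096, 4100)}
```

With a group_size of 128, a 4100-column weight leaves a final group of only 4 columns (4100 = 32 * 128 + 4), which is exactly what the check rejects.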

resolve_quantization_config

resolve_quantization_config() -> QuantizationConfig

Returns the quantization config specified by this modifier

Source code in src/llmcompressor/modifiers/quantization/quantization/mixin.py
def resolve_quantization_config(self) -> QuantizationConfig:
    """
    Returns the quantization config specified by this modifier
    """
    scheme = self.scheme
    targets = self.targets
    config_groups = self.config_groups
    kv_cache_scheme = self.kv_cache_scheme
    ignore = self.ignore

    if scheme is not None and config_groups is not None:
        raise ValueError("Please specify either `scheme` or `config_groups`")

    if scheme is not None:
        # takes precedence over config_groups

        if isinstance(scheme, str) and is_preset_scheme(scheme):
            # attach targets to scheme
            scheme = {scheme: targets}

        config_groups = {}
        for idx, key in enumerate(scheme.keys()):
            if is_preset_scheme(key):
                scheme_obj = preset_name_to_scheme(key, scheme[key])
            else:
                scheme_obj = QuantizationScheme.model_validate(
                    {"targets": scheme[key], **scheme}
                )

            # Apply observer overrides if specified
            scheme_obj = self._apply_observer_overrides(scheme_obj)

            group_name = f"group_{idx}"
            config_groups[group_name] = scheme_obj

    if config_groups is None or len(config_groups) == 0:
        default_quant_scheme = QuantizationScheme(targets=targets)
        # Apply observer overrides to default scheme as well
        default_quant_scheme = self._apply_observer_overrides(default_quant_scheme)
        config_groups = {"group_0": default_quant_scheme}
    elif scheme is None:
        # Apply observer overrides to all config groups when config_groups
        # was provided directly (not derived from scheme)
        for scheme_obj in config_groups.values():
            self._apply_observer_overrides(scheme_obj)

    return QuantizationConfig(
        config_groups=config_groups,
        kv_cache_scheme=kv_cache_scheme,
        quantization_status=QuantizationStatus.INITIALIZED,
        ignore=ignore,
    )
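The scheme branch above normalizes several input shapes into numbered config groups. The sketch below mirrors only the preset-name paths in plain Python: PRESET_NAMES is an assumed subset of preset names, toy_normalize_scheme is a hypothetical helper, and each group is a dict rather than a QuantizationScheme, so it illustrates group naming and target handling, not the library's validation.

```python
PRESET_NAMES = {"W8A8", "W4A16", "FP8_DYNAMIC"}  # assumed subset, for illustration


def toy_normalize_scheme(scheme, targets):
    """Mirror how a preset-based `scheme` becomes numbered config groups."""
    if isinstance(scheme, str) and scheme in PRESET_NAMES:
        scheme = {scheme: targets}  # a bare preset name inherits the modifier targets
    groups = {}
    for idx, (preset, preset_targets) in enumerate(scheme.items()):
        if preset not in PRESET_NAMES:
            raise ValueError(f"unknown preset {preset!r} in this sketch")
        groups[f"group_{idx}"] = {"preset": preset, "targets": list(preset_targets)}
    return groups


single = toy_normalize_scheme("W4A16", targets=["Linear"])
multi = toy_normalize_scheme({"W8A8": ["Linear"], "W4A16": ["lm_head"]}, targets=["Linear"])
```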

start_calibration

start_calibration(model: Module)

Attach observers, register activation calibration hooks (including kv_cache quantization) and enable quantization as we calibrate

Parameters:

  • model (Module) –

    model to prepare for calibration

Source code in src/llmcompressor/modifiers/quantization/quantization/mixin.py
def start_calibration(self, model: torch.nn.Module):
    """
    Attach observers, register activation calibration hooks (including
    kv_cache quantization) and enable quantization as we calibrate

    :param model: model to prepare for calibration
    """
    targets = match_named_modules(model, self.resolved_targets, self.ignore)
    if targets_embeddings(model, targets):
        untie_word_embeddings(model)

    for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
        self._initialize_observers(module)
        self._calibration_hooks |= self._initialize_hooks(module)
        apply_calibration_status(module)

    # Link weight observers in fused groups (Q/K/V, gate/up) for shared global_scale
    fuse_weight_observers(model)

sync_obs_act_stats

sync_obs_act_stats(modules: Iterator[Module])

Synchronize the activation statistics for observers across DDP ranks. Iterates over all observers (weight, input, output, q, k, v). This is a no-op when not running distributed; since most weight observers have no activation statistics, syncing them is a no-op as well.

Parameters:

  • modules (Iterator[Module]) –

    iterable of modules to sync (e.g., from a sequential chunk)

Source code in src/llmcompressor/modifiers/quantization/quantization/mixin.py
def sync_obs_act_stats(self, modules: Iterator[torch.nn.Module]):
    """
    Synchronize the activation statistics for observers
    across DDP ranks. Iterates all observers
    (weight, input, output, q, k, v);
    note: No-op when not distributed and
    most weight observers don't have activation
    statistics and thus are no-ops as well.

    :param modules: iterable of modules to sync (e.g., from a sequential chunk)
    """
    if not is_distributed():
        return

    pending_comms = []
    synced_obs = {None}  # ignore None observers
    for module in modules:
        for base_name in ACTIVATION_OBS + ("weight",):
            observer = getattr(module, f"{base_name}_observer", None)
            if observer not in synced_obs:
                synced_obs.add(observer)
                pending_comms.extend(observer.sync_activation_stats())
    wait_for_comms(pending_comms)
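The synced_obs set in sync_obs_act_stats ensures each observer object is synced at most once, even if it is reachable from several modules, and pre-seeding it with None skips modules that lack a given observer. A minimal sketch of that deduplication, assuming hypothetical ToyObserver objects that return placeholder strings instead of real distributed communication handles, and dict-based modules:

```python
class ToyObserver:
    """Hypothetical observer that records how often it was asked to sync."""

    def __init__(self, name):
        self.name = name
        self.sync_calls = 0

    def sync_activation_stats(self):
        self.sync_calls += 1
        return [f"handle:{self.name}"]  # stand-in for async communication handles


def toy_sync_observers(modules, base_names=("input", "output", "weight")):
    pending = []
    seen = {None}  # pre-seed with None so missing observers are skipped
    for module in modules:
        for base_name in base_names:
            observer = module.get(f"{base_name}_observer")
            if observer not in seen:  # identity-based: each object synced once
                seen.add(observer)
                pending.extend(observer.sync_activation_stats())
    return pending


shared = ToyObserver("shared_input")
weight_obs = ToyObserver("weight")
modules = [
    {"input_observer": shared},
    {"input_observer": shared, "weight_observer": weight_obs},
]
handles = toy_sync_observers(modules)
```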

validate_observer

validate_observer(value: Any) -> dict[str, str] | None

Validate observer dictionary format. Accepts keys: 'weights', 'input', 'output'

Source code in src/llmcompressor/modifiers/quantization/quantization/mixin.py
@field_validator("observer", mode="before")
def validate_observer(cls, value: Any) -> dict[str, str] | None:
    """
    Validate observer dictionary format. Accepts keys: 'weights', 'input', 'output'
    """
    if value is None:
        return value

    if not isinstance(value, dict):
        raise ValueError("`observer` must be a dictionary")

    valid_keys = {"weights", "input", "output"}
    for key in value.keys():
        if key not in valid_keys:
            raise ValueError(
                f"Invalid observer key '{key}'. Valid keys are: {valid_keys}"
            )
        if not isinstance(value[key], str):
            raise ValueError(f"Observer value for '{key}' must be a string")

    return value
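The observer parameters described earlier can be combined in more than one way. This sketch shows one plausible resolution, assuming precedence is applied key by key: an entry in the observer dict overrides the matching individual argument, and unset entries fall back to the scheme's default. toy_resolve_observers is a hypothetical helper; the library applies overrides through its own _apply_observer_overrides method, whose exact behavior is not shown here.

```python
def toy_resolve_observers(weight_observer=None, input_observer=None,
                          output_observer=None, observer=None):
    """Combine per-type observer arguments with the `observer` dict (sketch)."""
    resolved = {
        "weights": weight_observer,
        "input": input_observer,
        "output": output_observer,
    }
    # Entries from the `observer` dict win over the individual arguments.
    for key, name in (observer or {}).items():
        resolved[key] = name
    # Drop unset entries so the scheme's default observer is kept for them.
    return {key: name for key, name in resolved.items() if name is not None}


overrides = toy_resolve_observers(
    weight_observer="minmax",
    observer={"weights": "mse", "input": "mse"},
)
```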