llmcompressor.modifiers.transform.awq.base

Classes:

  • AWQModifier

    Implements the AWQ (Activation-aware Weight Quantization) algorithm.

AWQModifier

Bases: Modifier

Implements the AWQ (Activation-aware Weight Quantization) algorithm, as described in https://arxiv.org/pdf/2306.00978. The paper observes that protecting only about 1% of the weight channels, the most salient ones, greatly reduces quantization error.

Instead of relying on raw weight values, AWQ identifies important channels by analyzing activation patterns, focusing on the channels in the weight tensor that are most responsive to the input. To reduce quantization error, it scales these channels in a way that preserves the model's original behavior, using scaling factors computed offline from activation statistics.
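
The equivalence behind this scaling can be checked on a toy layer. The sketch below is illustrative only: the values and the `scales` vector are made up, and a plain Python matrix-vector product stands in for a linear layer. It shows that multiplying a weight's input channels by `s` while dividing the matching activation channels by `s` leaves the output unchanged, which is why the scaling can be applied before quantization without altering the unquantized model.

```python
# Toy check: scaling weight input channels by s while dividing the matching
# activation channels by s leaves the layer output unchanged.

def linear(x, weight):
    """y[j] = sum_i x[i] * weight[i][j] for a single input vector x."""
    n_out = len(weight[0])
    return [sum(x[i] * weight[i][j] for i in range(len(x))) for j in range(n_out)]

x = [0.5, 8.0, -0.25]        # channel 1 carries large activations
weight = [[0.2, -0.4],       # weight[i][j]: input channel i, output channel j
          [0.1, 0.3],
          [-0.6, 0.5]]
scales = [1.0, 4.0, 1.0]     # hypothetical per-input-channel scales

# Activations are divided by the scales, weights are multiplied by them.
x_smoothed = [xi / s for xi, s in zip(x, scales)]
weight_scaled = [[w * s for w in row] for row, s in zip(weight, scales)]

y_ref = linear(x, weight)
y_new = linear(x_smoothed, weight_scaled)
max_diff = max(abs(a - b) for a, b in zip(y_ref, y_new))
```

In the recipe's terms, the division would be folded into the `smooth_layer` output and the multiplication into the `balance_layers` weights; that reading is inferred from the parameter descriptions on this page rather than from the implementation.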

Because this modifier manipulates the weights of the model, it can only be used in one-shot mode and not during training. Activation ranges are determined by running a small set of calibration data through the model.

AWQModifier is a transform-based modifier, in that it does not perform quantization or compression on its own. It just scales activation channels according to a quantization scheme. It must be applied in conjunction with a modifier that inherits from QuantizationMixin in order to create a compressed checkpoint.

example recipe:

AWQModifier:
  mappings:
    - smooth_layer: "re:.*self_attn_layer_norm"
      balance_layers: ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]
    - smooth_layer: "re:.*final_layer_norm"
      balance_layers: ["re:.*fc1"]
    # activation_hook_target specifies which submodule of the parent to hook
    # for activation caching.
    # This change is only useful for MoE models with parallel transformer blocks,
    # and one should use the default value (None) in most cases.
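
To see which modules a mapping like the one above would select, the following sketch resolves the `re:` targets against a list of module names. It is a simplified illustration: the module names are hypothetical, the `matches` helper is not part of the library, and using `re.match` is an assumption. In particular, it does not model the "largest overlap in module name" rule described under Parameters.

```python
import re

def matches(target: str, module_name: str) -> bool:
    """Return True if module_name is selected by target.

    Assumption: targets prefixed with "re:" are regular expressions and
    anything else is an exact module name.
    """
    if target.startswith("re:"):
        return re.match(target[3:], module_name) is not None
    return target == module_name

# Hypothetical module names for a single decoder layer.
module_names = [
    "model.decoder.layers.0.self_attn_layer_norm",
    "model.decoder.layers.0.self_attn.q_proj",
    "model.decoder.layers.0.self_attn.k_proj",
    "model.decoder.layers.0.self_attn.v_proj",
    "model.decoder.layers.0.final_layer_norm",
    "model.decoder.layers.0.fc1",
]

mapping = {
    "smooth_layer": "re:.*self_attn_layer_norm",
    "balance_layers": ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"],
}

smooth = [n for n in module_names if matches(mapping["smooth_layer"], n)]
balance = [
    n for n in module_names
    if any(matches(t, n) for t in mapping["balance_layers"])
]
```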

Lifecycle:

  • on_initialize
    • set unresolved mappings if not set by user, based on model architecture
  • (quantization config applied by subsequent QuantizationMixin's on_initialize)
  • on_start
    • resolve mappings
    • capture kwargs needed for forward passes into modules
    • set up activation cache hooks to capture input activations to balance layers
  • on sequential epoch end
    • apply smoothing to each smoothing layer
      • consume cached activations across all batches
        • clear cached activations as they are used
      • find best smoothing scale for each smoothing layer via grid search
      • apply best scales to model weights
      • raise error if any unused activations remain
  • on_end
    • re-run logic of sequential epoch end (in case of basic pipeline)
    • remove activation hooks
  • on_finalize
    • clear resolved mappings and captured activations
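
The cache handling in this lifecycle (accumulate across batches, consume and clear when smoothing, fail if anything is left over) can be sketched as follows. `ActivationCache` and its methods are hypothetical names used for illustration, and the per-channel mean of absolute values is an assumed statistic; the modifier's own cache structure may differ.

```python
class ActivationCache:
    """Hypothetical stand-in for per-layer activation statistics."""

    def __init__(self):
        self._abs_sums = {}  # layer name -> per-channel sum of |activation|
        self._counts = {}    # layer name -> number of token vectors seen

    def record(self, layer, batch):
        """Accumulate statistics for one batch (a list of token vectors)."""
        sums = self._abs_sums.setdefault(layer, [0.0] * len(batch[0]))
        for token in batch:
            for i, value in enumerate(token):
                sums[i] += abs(value)
        self._counts[layer] = self._counts.get(layer, 0) + len(batch)

    def consume(self, layer):
        """Return the per-channel mean and clear the entry as it is used."""
        sums = self._abs_sums.pop(layer)
        count = self._counts.pop(layer)
        return [s / count for s in sums]

    def assert_all_consumed(self):
        """Raise if any cached activations were never used."""
        if self._abs_sums:
            raise RuntimeError(
                f"unused activations remain for: {sorted(self._abs_sums)}"
            )

cache = ActivationCache()
cache.record("layers.0.q_proj", [[1.0, -2.0], [3.0, 4.0]])  # batch 1
cache.record("layers.0.q_proj", [[-1.0, 0.0]])              # batch 2
mean_abs = cache.consume("layers.0.q_proj")
cache.assert_all_consumed()  # passes: the only entry was consumed
```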

Parameters:

  • mappings

    List of the activation layers to smooth and the layers whose output is scaled so that those activations are smoothed. Each entry of the mapping list should itself be a list, in which the first entry is a list of layers that share the same input activation (the one to be smoothed) and the second entry is the layer whose output is scaled to achieve the smoothing. If a regex is used, it matches layers with the largest overlap in module name. Each mapping may also include an activation_hook_target: a dotted attribute path relative to the parent module (lowest common ancestor) specifying which submodule to hook for activation caching. This is useful for parallel transformer blocks where the default (hooking balance_layers[0]) would capture the wrong activations.

  • offload_device

    Device to offload cached args to, which reduces memory requirements but requires more time to move data between the CPU and the execution device. Defaults to None, so cached args are not offloaded; per the on_initialize source below, MoE models default to torch.device("cpu") when no value is provided. Consider setting this to torch.device("cpu") if you are encountering OOM errors.

  • duo_scaling

    Whether to use duo scaling, which uses both input activations and weights to determine the scaling factor. Defaults to True. If True, both activations and weights are used. If False, only activations are used. If "both", half the grid search is performed with duo_scaling=False and the other half is performed with duo_scaling=True.

  • n_grid

    Number of grid points to use when searching for the best scales for each mapping. Decreasing it reduces runtime, at the possible cost of slightly worse scales. Defaults to 20.
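
How `n_grid` and `duo_scaling` interact can be illustrated with a toy grid search. The sketch below is an assumption-laden illustration, not this module's implementation: the scale formulas follow a common formulation of AWQ-style search (activation mean raised to a ratio, optionally divided by a weight-mean term), `fake_quantize` is a made-up round-to-nearest quantizer, and the "both" mode is not modeled. All function names are hypothetical.

```python
import math

def fake_quantize(row, n_levels=7):
    """Toy symmetric round-to-nearest quantizer for one output channel."""
    max_abs = max(abs(w) for w in row) or 1.0
    step = max_abs / n_levels
    return [round(w / step) * step for w in row]

def forward(x_batch, weight):
    """weight[j][i] maps input channel i to output channel j."""
    return [[sum(xi * wi for xi, wi in zip(x, row)) for row in weight]
            for x in x_batch]

def mse(a, b):
    """Mean squared difference between two nested lists of outputs."""
    diffs = [(p - q) ** 2 for ra, rb in zip(a, b) for p, q in zip(ra, rb)]
    return sum(diffs) / len(diffs)

def quantize_with_scales(weight, scales):
    """Scale input channels, quantize, then undo the scaling."""
    result = []
    for row in weight:
        quantized = fake_quantize([w * s for w, s in zip(row, scales)])
        result.append([q / s for q, s in zip(quantized, scales)])
    return result

def search_scales(x_batch, weight, n_grid=20, duo_scaling=True):
    """Try n_grid candidate scale vectors and keep the lowest-error one."""
    n_in = len(weight[0])
    # Per-input-channel mean magnitudes, clamped away from zero.
    x_mean = [max(sum(abs(x[i]) for x in x_batch) / len(x_batch), 1e-8)
              for i in range(n_in)]
    w_mean = [max(sum(abs(row[i]) for row in weight) / len(weight), 1e-8)
              for i in range(n_in)]
    reference = forward(x_batch, weight)

    best_error, best_scales = math.inf, [1.0] * n_in
    for step in range(n_grid):
        ratio = step / n_grid
        if duo_scaling:
            # Assumed duo form: activations and weights both contribute.
            scales = [(xm ** ratio) / (wm ** (1 - ratio))
                      for xm, wm in zip(x_mean, w_mean)]
        else:
            scales = [xm ** ratio for xm in x_mean]
        norm = math.sqrt(max(scales) * min(scales))
        scales = [s / norm for s in scales]

        candidate = quantize_with_scales(weight, scales)
        error = mse(reference, forward(x_batch, candidate))
        if error < best_error:
            best_error, best_scales = error, scales
    return best_scales, best_error

x_batch = [[0.1, 6.0, -0.2], [-0.3, 5.0, 0.1], [0.2, -7.0, 0.3]]
weight = [[0.31, -0.12, 0.77], [-0.55, 0.09, 0.48]]
scales, error = search_scales(x_batch, weight, n_grid=20, duo_scaling=False)
```

With `duo_scaling=False`, the first grid point uses a ratio of 0 and therefore scales of exactly 1, so the search result is never worse than quantizing without scaling in this sketch. A smaller `n_grid` simply evaluates fewer candidate ratios.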

Methods:

  • on_end

    Finish calibrating by removing observers and calibration hooks.

  • on_finalize

    Clean up by clearing the activations and mapping data

  • on_initialize

    Start AWQ on the given state. This runs before quantization config has been applied.

  • on_start

    Start AWQ on the given state. This runs after quantization mixin has been initialized (i.e. after quantization config has been applied).

  • validate_duo_scaling

    Validate that duo_scaling is either True, False, or 'both' (lowercase)

on_end

on_end(state: State, event: Event, **kwargs)

Finish calibrating by removing observers and calibration hooks. No qparams are updated since this is just a transform.

Source code in src/llmcompressor/modifiers/transform/awq/base.py
def on_end(self, state: State, event: Event, **kwargs):
    """
    Finish calibrating by removing observers and calibration hooks.
    No qparams are updated since this is just a transform.
    """
    self._assert_all_activations_consumed()

    self.ended_ = True

    # remove activation hooks
    self.remove_hooks()

on_finalize

on_finalize(state: State, **kwargs) -> bool

Clean up by clearing the activations and mapping data

Parameters:

  • state (State) –

    unused

Returns:

  • bool

    True

Source code in src/llmcompressor/modifiers/transform/awq/base.py
def on_finalize(self, state: State, **kwargs) -> bool:
    """
    Clean up by clearing the activations and mapping data

    :param state: unused
    :return: True
    """
    if not self.ended_:
        self.on_end(state, None)

    self._log_error_metrics()

    self._parent_args_cache.clear()
    self._smooth_activation_stats.clear()
    self._resolved_mappings.clear()
    self._error_metrics.clear()

    return True

on_initialize

on_initialize(state: State, **kwargs) -> bool

Start AWQ on the given state. This runs before quantization config has been applied

  • infer unresolved mappings based on model architecture, if not set manually

Parameters:

  • state (State) –

    state to run AWQ on

Returns:

  • bool

    True on a successful run, False otherwise

Source code in src/llmcompressor/modifiers/transform/awq/base.py
def on_initialize(self, state: State, **kwargs) -> bool:
    """
    Start AWQ on the given state. This runs before quantization config has been
    applied

    - infer unresolved mappings based on model architecture, if not set manually

    :param state: state to run AWQ on
    :return: True on a successful run, False otherwise
    """

    if self.mappings is None:
        logger.info("No AWQModifier.mappings provided, inferring from model...")
        self.mappings = get_layer_mappings_from_model(state.model)

    # Set default offload_device
    if self.offload_device == Sentinel("not_provided"):
        # Check if we have a MoE model
        if is_moe_model(state.model):
            self.offload_device = torch.device("cpu")
            logger.info(
                "MoE model detected: setting offload_device to 'cpu' by default "
                "to reduce memory usage. You can override this by explicitly "
                "setting offload_device in your recipe."
            )
        else:
            # For non-MoE models, convert sentinel to None
            # (no offloading by default)
            self.offload_device = None

    return True
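
The sentinel check in on_initialize lets the modifier tell "the user passed nothing" apart from "the user explicitly passed None". A minimal sketch of that pattern follows; the `Sentinel` class here is a hypothetical stand-in written for illustration (the library's own `Sentinel` may be implemented differently), and devices are represented as plain strings rather than `torch.device` objects.

```python
class Sentinel:
    """Hypothetical stand-in: a named marker that compares equal by name."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        return isinstance(other, Sentinel) and other.name == self.name

    def __hash__(self):
        return hash(self.name)

    def __repr__(self):
        return f"Sentinel({self.name!r})"

NOT_PROVIDED = Sentinel("not_provided")

def resolve_offload_device(offload_device, is_moe):
    """Mirror the branching shown above, with strings standing in for devices."""
    if offload_device == NOT_PROVIDED:
        # Nothing was set by the user: pick a default from the architecture.
        return "cpu" if is_moe else None
    # An explicit value, including an explicit None, is kept as-is.
    return offload_device

default_moe = resolve_offload_device(NOT_PROVIDED, is_moe=True)
default_dense = resolve_offload_device(NOT_PROVIDED, is_moe=False)
explicit_none = resolve_offload_device(None, is_moe=True)
explicit_gpu = resolve_offload_device("cuda:0", is_moe=True)
```

The design choice is that an explicit `None` on an MoE model still disables offloading, which would be impossible if `None` itself were used as the "unset" marker.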

on_start

on_start(state: State, event: Event, **kwargs)

Start AWQ on the given state. This runs after quantization mixin has been initialized (i.e. after quantization config has been applied)

  • resolve mappings
  • validate mappings and quant scheme
  • setup activation cache hooks

Parameters:

  • state (State) –

    state to run AWQ on

Returns:

  • True on a successful run, False otherwise, according to the docstring; the source shown below contains no explicit return statement

Source code in src/llmcompressor/modifiers/transform/awq/base.py
def on_start(self, state: State, event: Event, **kwargs):
    """
    Start AWQ on the given state. This runs after quantization mixin has been
    initialized (i.e. after quantization config has been applied)

    - resolve mappings
    - validate mappings and quant scheme
    - setup activation cache hooks

    :param state: state to run AWQ on
    :return: True on a successful run, False otherwise
    """
    self.started_ = True

    self._set_resolved_mappings(state.model)

    # Check for unsupported token masking with MoE up_proj -> down_proj mappings
    if state.loss_masks is not None and self._has_moe_up_down_proj_mapping():
        raise ValueError(
            "Token masking (use_loss_mask=True) is not supported with "
            "up_proj -> down_proj mappings in MoE models. The MoE routing "
            "mechanism dispatches tokens to different experts, and the loss mask "
            "cannot be properly aligned with this dispatch. Please either "
            "disable token masking or exclude the up_proj -> down_proj mapping "
            "for MoE layers from the AWQ configuration."
        )

    # Validate that duo_scaling is only used with per-channel quantization
    if self.duo_scaling is not False:

        def _validate_balance_layer(name, module):
            if (
                hasattr(module, "quantization_scheme")
                and hasattr(module.quantization_scheme, "weights")
                and module.quantization_scheme.weights.strategy
                == QuantizationStrategy.TENSOR
            ):
                raise ValueError(
                    "duo_scaling is only supported with per-channel quantization "
                    "strategies (group or channel), but found TENSOR strategy on "
                    f"layer {name}. Please set duo_scaling=False or use a "
                    "per-channel quantization strategy."
                )

        for mapping in self._resolved_mappings:
            for balance_name, balance_layer in zip(
                mapping.balance_names, mapping.balance_layers
            ):
                _validate_balance_layer(balance_name, balance_layer)

    self._setup_activation_cache_hooks()

validate_duo_scaling classmethod

validate_duo_scaling(v)

Validate that duo_scaling is either True, False, or 'both' (lowercase)

Source code in src/llmcompressor/modifiers/transform/awq/base.py
@field_validator("duo_scaling")
@classmethod
def validate_duo_scaling(cls, v):
    """Validate that duo_scaling is either True, False, or 'both' (lowercase)"""
    if v not in (True, False, "both"):
        raise ValueError(f"duo_scaling must be True, False, or 'both', got {v!r}")
    return v
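
The membership test used by this validator can be exercised on its own. The sketch below copies the check into a plain function (`check_duo_scaling` and `is_accepted` are illustrative names) and runs it outside of pydantic, so any coercion or type checking that pydantic performs before the validator runs is not modeled here.

```python
def check_duo_scaling(v):
    """Same membership test as the validator above, outside of pydantic."""
    if v not in (True, False, "both"):
        raise ValueError(f"duo_scaling must be True, False, or 'both', got {v!r}")
    return v

def is_accepted(v):
    """Return True if the check passes, False if it raises ValueError."""
    try:
        check_duo_scaling(v)
        return True
    except ValueError:
        return False

candidates = (True, False, "both", "Both", "true", None, 1, 0)
results = {repr(v): is_accepted(v) for v in candidates}
```

Because tuple membership compares with `==`, the integers `1` and `0` pass this check (`1 == True` and `0 == False` in Python), while `"Both"`, `"true"`, and `None` are rejected.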

get_lowest_common_ancestor_with_avoid

get_lowest_common_ancestor_with_avoid(
    balance_names: Iterator[str],
    model: Module,
    avoid=torch.nn.ModuleList,
)

Get the lowest ancestor that is not the avoided class/type. See compressed_tensors.utils.get_lowest_common_ancestor_name for details on case handling.

NOTE: primarily used to exclude parents of type ModuleList, which don't play nicely with hooks because their forward method is never directly called for MoE models. See Qwen3MoeSparseMoeBlock for an example: experts are selected based on router output, and each selected expert's forward method is called instead. https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233

Source code in src/llmcompressor/modifiers/transform/awq/base.py
def get_lowest_common_ancestor_with_avoid(
    balance_names: Iterator[str], model: Module, avoid=torch.nn.ModuleList
):
    """
    Get the lowest ancestor that is not the avoided class/type.
    see compressed_tensors.utils.get_lowest_common_ancestor_name
    for detail on case handling.

    NOTE: primarily used to exclude parents of type ModuleList, which don't play
    nicely with hooks because their forward method is never directly
    called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts
    are selected based on router output and their forward method is called.
    https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
    """
    ancestor_name = get_lowest_common_ancestor_name(balance_names)

    while True:
        if ancestor_name == "":
            return "", model
        ancestor = model.get_submodule(ancestor_name)
        if not isinstance(ancestor, avoid):
            return ancestor_name, ancestor
        ancestor_name = ".".join(ancestor_name.split(".")[:-1])
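
The upward walk can be traced without torch by using a dictionary of module names and type labels. Everything in this sketch is a stand-in: `lowest_common_ancestor_name` is a simplified longest-shared-prefix helper that does not reproduce the case handling of the compressed_tensors function, and the module names and type labels are hypothetical.

```python
def lowest_common_ancestor_name(names):
    """Simplified stand-in: longest shared dotted prefix of the given names."""
    split = [name.split(".") for name in names]
    common = []
    for parts in zip(*split):
        if len(set(parts)) != 1:
            break
        common.append(parts[0])
    return ".".join(common)

def ancestor_with_avoid(balance_names, module_types, avoid="ModuleList"):
    """Walk up from the common ancestor until its type is not the avoided one."""
    name = lowest_common_ancestor_name(balance_names)
    while True:
        if name == "":
            return ""  # reached the root model
        if module_types[name] != avoid:
            return name
        # Drop the last path component to move to the parent module.
        name = ".".join(name.split(".")[:-1])

# Hypothetical MoE layout: module name -> type label.
module_types = {
    "model": "Model",
    "model.layers": "ModuleList",
    "model.layers.0": "DecoderLayer",
    "model.layers.0.mlp": "SparseMoeBlock",
    "model.layers.0.mlp.experts": "ModuleList",
    "model.layers.0.mlp.experts.0": "Expert",
    "model.layers.0.mlp.experts.1": "Expert",
}

across_experts = ancestor_with_avoid(
    ["model.layers.0.mlp.experts.0.up_proj",
     "model.layers.0.mlp.experts.1.up_proj"],
    module_types,
)
within_expert = ancestor_with_avoid(
    ["model.layers.0.mlp.experts.0.gate_proj",
     "model.layers.0.mlp.experts.0.up_proj"],
    module_types,
)
```

When the balance layers span several experts, the shared prefix is the `experts` container, so the walk moves one level up to the MoE block; when they sit inside one expert, that expert is returned directly.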