llmcompressor.modifiers.quantization.gptq

Classes:

  • GPTQModifier –

    Implements the GPTQ algorithm from https://arxiv.org/abs/2210.17323.

Functions:

  • quantize_weight –

    Quantize a module weight according to the GPTQ algorithm

GPTQModifier

Bases: Modifier, QuantizationMixin

Implements the GPTQ algorithm from https://arxiv.org/abs/2210.17323. This modifier uses activations to calibrate a hessian matrix, which is then used to determine optimal quantization values and orderings for the model weights.

Sample yaml:

test_stage:
  obcq_modifiers:
    GPTQModifier:
      block_size: 128
      dampening_frac: 0.001
      offload_hessians: False
      actorder: static
      config_groups:
        group_0:
          targets:
            - "Linear"
          input_activations: null
          output_activations: null
          weights:
            num_bits: 8
            type: "int"
            symmetric: true
            strategy: group
            group_size: 128

Lifecycle:

  • on_initialize
    • apply config to model
  • on_start
    • add activation calibration hooks
    • add gptq weight calibration hooks
  • on_sequential_epoch_end
    • quantize_weight
  • on_finalize
    • remove_hooks()
    • model.apply(freeze_module_quantization)

Parameters:

  • block_size –

    Used to determine number of columns to compress in one pass

  • dampening_frac –

    Amount of dampening to apply to the Hessian H, as a fraction of the mean of its diagonal

  • actorder –

    order in which weight columns are quantized. Defaults to "static" activation ordering, which achieves best accuracy recovery with no runtime cost. For more information, see https://github.com/vllm-project/vllm/pull/8135

  • offload_hessians –

    Set to True for decreased memory usage but increased runtime.

  • config_groups –

    dictionary specifying quantization schemes to apply to target modules. Modules not matching a scheme target will NOT be quantized.

  • targets –

    list of layer names to quantize if a scheme is provided. Defaults to Linear layers

  • ignore –

    optional list of module class names or submodule names to not quantize even if they match a target in config_groups. Defaults to empty list.

  • scheme –

    A single quantization scheme to apply to the model. This is a dictionary that supports all keys from QuantizationScheme except targets, which is set from the targets parameter at the modifier level. It can also be a dictionary of the form preset_scheme_name: targets, for example W8A8: ['Linear'] for 8-bit weights and activations.

  • kv_cache_scheme –

    Optional QuantizationArgs that specify the quantization of the kv cache. If None, the kv cache is not quantized. When applying kv cache quantization to a transformers AutoModelForCausalLM, the kv_cache_scheme is converted into a QuantizationScheme that:

      • targets the k_proj and v_proj modules of the model, whose outputs are the keys and values that might be cached
      • quantizes the outputs of those modules, so that keys and values are compressed before being stored in the cache

    This explicitly assumes that the model contains modules with k_proj and v_proj in their names. If it does not and kv_cache_scheme is not None, kv cache quantization will fail.
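
How targets and ignore interact can be illustrated with a small standalone sketch. It is not the library's matching logic: FakeModule and select_modules are hypothetical names, and the sketch assumes an entry matches only when it equals a module's class name or its full dotted name.

```python
from dataclasses import dataclass


@dataclass
class FakeModule:
    """Hypothetical stand-in for a model submodule: dotted name plus class name."""
    name: str
    class_name: str


def select_modules(modules, targets, ignore):
    """Return names of modules that match a target and are not ignored.

    Simplifying assumption: an entry matches when it equals either the
    module's class name or its full dotted name.
    """
    def matches(module, patterns):
        return module.class_name in patterns or module.name in patterns

    return [m.name for m in modules if matches(m, targets) and not matches(m, ignore)]


modules = [
    FakeModule("model.layers.0.self_attn.q_proj", "Linear"),
    FakeModule("model.layers.0.mlp.down_proj", "Linear"),
    FakeModule("model.norm", "RMSNorm"),
    FakeModule("lm_head", "Linear"),
]

# lm_head is a Linear layer, but the ignore list takes precedence
selected = select_modules(modules, targets=["Linear"], ignore=["lm_head"])
print(selected)
```

Here model.norm is skipped because it matches no target, and lm_head is skipped because it is ignored even though it matches the Linear target.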

Methods:

  • calibrate_module –

    Calibration hook used to accumulate the hessian of the input to the module

  • compress_modules –

    Quantize modules which have been calibrated

  • on_end –

    Finish calibrating by removing observers and calibration hooks

  • on_finalize –

    disable the quantization observers used by the OBCQ algorithm

  • on_initialize –

    Initialize and run the GPTQ algorithm on the current state

calibrate_module

calibrate_module(
    module: Module,
    args: tuple[Tensor, ...],
    _output: Tensor,
)

Calibration hook used to accumulate the hessian of the input to the module

Parameters:

  • module (Module) –

    module being calibrated

  • args (tuple[Tensor, ...]) –

    inputs to the module, the first element of which is the canonical input

  • _output (Tensor) –

    uncompressed module output, unused
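
As a rough illustration of what accumulating a Hessian over calibration batches can look like, the sketch below keeps a running estimate of (2 / n) · Σ x xᵀ over all input rows seen so far, following the Hessian definition in the GPTQ paper. The helper names toy_empty_hessian and toy_accumulate_hessian are hypothetical, and the exact scaling used by the library's accumulate_hessian is an assumption here.

```python
def toy_empty_hessian(num_features):
    """Square matrix of zeros, one row and column per input feature."""
    return [[0.0] * num_features for _ in range(num_features)]


def toy_accumulate_hessian(batch, hessian, num_samples):
    """Fold a batch of input rows into a running estimate of (2 / n) * sum(x x^T)."""
    total = num_samples + len(batch)
    keep = num_samples / total  # down-weight the previous average
    size = len(hessian)
    for i in range(size):
        for j in range(size):
            outer_sum = sum(x[i] * x[j] for x in batch)
            hessian[i][j] = hessian[i][j] * keep + 2.0 * outer_sum / total
    return hessian, total


# two calibration batches with two input features each
H, n = toy_empty_hessian(2), 0
H, n = toy_accumulate_hessian([[1.0, 2.0]], H, n)
H, n = toy_accumulate_hessian([[3.0, 0.0], [0.0, 1.0]], H, n)
print(n, H)
```

Processing the batches one at a time gives the same matrix as processing all three rows at once, which is why the hook can be called once per forward pass.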

Source code in src/llmcompressor/modifiers/gptq/base.py
def calibrate_module(
    self,
    module: torch.nn.Module,
    args: tuple[torch.Tensor, ...],
    _output: torch.Tensor,
):
    """
    Calibration hook used to accumulate the hessian of the input to the module

    :param module: module being calibrated
    :param args: inputs to the module, the first element of which is the
        canonical input
    :param _output: uncompressed module output, unused
    """
    # Assume that first argument is the input
    inp = args[0]

    # Initialize hessian if not present
    if module not in self._num_samples:
        init_device = (
            "cpu" if self.offload_hessians else get_execution_device(module)
        )
        self._hessians[module] = make_empty_hessian(module, device=init_device)
        self._num_samples[module] = torch.zeros(
            tuple(), device=get_execution_device(module)
        )

    # Accumulate hessian with input with optional offloading
    with self._maybe_onload_hessian(module):
        self._hessians[module], self._num_samples[module] = accumulate_hessian(
            inp,
            module,
            self._hessians[module],
            self._num_samples[module],
        )

compress_modules

compress_modules()

Quantize modules which have been calibrated

Source code in src/llmcompressor/modifiers/gptq/base.py
def compress_modules(self):
    """
    Quantize modules which have been calibrated
    """
    ### Not Distributed
    if not is_distributed():
        self.compress_module_list(list(self._num_samples.keys()))
        return

    ### Distributed
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Assign modules to ranks
    module_list, rank_to_modules, module_to_rank = greedy_bin_packing(
        list(self._hessians.keys()),
        world_size,
        item_weight_fn=lambda mod: self._hessians[mod].shape[0],
    )

    # send hessians to assigned ranks
    self._reduce_hessian_to_target_rank(module_list, module_to_rank)

    self.compress_module_list(rank_to_modules[rank])

    # broadcast compressed modules to each rank
    self._broadcast_quantized_params(module_list, module_to_rank)
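
The distributed branch above relies on greedy_bin_packing to spread modules across ranks. The following is a minimal sketch of one common greedy strategy (heaviest item first, always into the currently lightest bin); toy_greedy_bin_packing is a hypothetical name, and whether the library's helper uses exactly this ordering is an assumption.

```python
import heapq


def toy_greedy_bin_packing(items, num_bins, item_weight_fn):
    """Assign items to bins, heaviest first, always into the lightest bin."""
    ordered = sorted(items, key=item_weight_fn, reverse=True)
    # heap entries are (current load, bin index)
    heap = [(0, b) for b in range(num_bins)]
    heapq.heapify(heap)
    bin_to_items = {b: [] for b in range(num_bins)}
    item_to_bin = {}
    for item in ordered:
        load, b = heapq.heappop(heap)
        bin_to_items[b].append(item)
        item_to_bin[item] = b
        heapq.heappush(heap, (load + item_weight_fn(item), b))
    return ordered, bin_to_items, item_to_bin


# illustrative Hessian side lengths, standing in for hessian.shape[0]
hessian_sizes = {"q_proj": 4096, "k_proj": 4096, "down_proj": 11008, "up_proj": 4096}
ordered, rank_to_modules, module_to_rank = toy_greedy_bin_packing(
    list(hessian_sizes), num_bins=2, item_weight_fn=hessian_sizes.get
)
print(rank_to_modules)
```

With two ranks, the single large module lands on one rank and the three smaller ones on the other, so the per-rank totals (11008 and 12288) stay close.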

on_end

on_end(state: State, event: Event, **kwargs)

Finish calibrating by removing observers and calibration hooks

Source code in src/llmcompressor/modifiers/gptq/base.py
def on_end(self, state: State, event: Event, **kwargs):
    """
    Finish calibrating by removing observers and calibration hooks
    """
    self.ended_ = True
    QuantizationMixin.end_calibration(self, state.model)
    self.remove_hooks()  # remove gptq hooks

on_finalize

on_finalize(state: State, **kwargs) -> bool

disable the quantization observers used by the OBCQ algorithm

Parameters:

  • state (State) –

    session state storing input model and calibration data

Source code in src/llmcompressor/modifiers/gptq/base.py
def on_finalize(self, state: State, **kwargs) -> bool:
    """
    disable the quantization observers used by the OBCQ algorithm

    :param state: session state storing input model and calibration data
    """
    if not self.ended_:
        self.on_end(state, None)

    if len(self._num_samples) > 0:
        raise ValueError(f"Failed to compress {len(self._num_samples)} modules")

    self._hessians = dict()
    self._num_samples = dict()

    return True

on_initialize

on_initialize(state: State, **kwargs) -> bool

Initialize and run the GPTQ algorithm on the current state

Parameters:

  • state (State) –

    session state storing input model and calibration data

Source code in src/llmcompressor/modifiers/gptq/base.py
def on_initialize(self, state: State, **kwargs) -> bool:
    """
    Initialize and run the GPTQ algorithm on the current state

    :param state: session state storing input model and calibration data
    """
    # apply config to model and prepare calibration hooks
    if QuantizationMixin.has_config(self):
        QuantizationMixin.initialize_quantization(self, state.model)

    # prepare module names
    self._module_names = {
        m: name
        for name, m in match_named_modules(
            state.model, self.resolved_targets, self.ignore
        )
    }

    return True

quantize_weight

quantize_weight(
    module: Module,
    quant_args: QuantizationArgs,
    hessian: Tensor,
    blocksize: int = 128,
    percdamp: float = 0.01,
) -> tuple[float, dict[str, torch.Tensor]]

Quantize a module weight according to the GPTQ algorithm

Parameters:

  • module (Module) –

    module with weight being quantized

  • quant_args (QuantizationArgs) –

    quantization arguments used to find quantization parameters

  • hessian (Tensor) –

    preaccumulated hessian for quantization

  • blocksize (int, default: 128 ) –

    chunk size of quantization updates

  • percdamp (float, default: 0.01 ) –

    dampening factor on hessian diagonal
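
The effect of percdamp can be shown on a tiny matrix. The sketch below mirrors the dampening step in the source listing (percdamp times the mean of the diagonal, added to every diagonal entry) in plain Python; dampen_diagonal is a hypothetical helper name.

```python
def dampen_diagonal(hessian, percdamp=0.01):
    """Return a damped copy of `hessian` and the amount added to its diagonal."""
    size = len(hessian)
    mean_diag = sum(hessian[i][i] for i in range(size)) / size
    damp = percdamp * mean_diag
    damped = [row[:] for row in hessian]  # leave the input untouched
    for i in range(size):
        damped[i][i] += damp
    return damped, damp


H = [[4.0, 2.0], [2.0, 1.0]]  # singular: the determinant is exactly 0
damped, damp = dampen_diagonal(H, percdamp=0.01)
det = damped[0][0] * damped[1][1] - damped[0][1] * damped[1][0]
print(damp, det)
```

The undamped matrix cannot be inverted, while the damped copy has a small positive determinant, which is the reason a larger dampening value can help when the Cholesky factorization fails.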

Returns:

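  • tuple[float, dict[str, Tensor]] –

    loss and a dictionary of quantized parameters with keys weight, weight_scale and weight_zero_point, plus weight_global_scale when a global scale is present and weight_g_idx when actorder is "group"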
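
The group index (g_idx) maps each weight column to the quantization group it belongs to, and activation ordering permutes columns before quantization and restores them afterwards. A minimal sketch of both ideas in plain Python follows; the permutation is made up for illustration and the helper names are hypothetical.

```python
def group_index(num_columns, group_size):
    """Map each column index to the index of the group it belongs to."""
    return [col // group_size for col in range(num_columns)]


def argsort(values):
    """Indices that would sort `values` in ascending order."""
    return sorted(range(len(values)), key=values.__getitem__)


g_idx = group_index(num_columns=8, group_size=4)

# made-up permutation standing in for an activation ordering
perm = [5, 0, 7, 2, 1, 6, 3, 4]
invperm = argsort(perm)

columns = list("abcdefgh")
permuted = [columns[p] for p in perm]       # order used during quantization
restored = [permuted[i] for i in invperm]   # back to the original order

# group of each original column, given that groups were formed after permuting
g_idx_original = [g_idx[i] for i in invperm]
print(g_idx, restored, g_idx_original)
```

Indexing with the inverse permutation undoes the reordering, and applying it to g_idx tells each original column which group of the permuted layout its scale came from.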

Source code in src/llmcompressor/modifiers/gptq/gptq_quantize.py
def quantize_weight(
    module: torch.nn.Module,
    quant_args: QuantizationArgs,
    hessian: torch.Tensor,
    blocksize: int = 128,
    percdamp: float = 0.01,
) -> tuple[float, dict[str, torch.Tensor]]:
    """
    Quantize a module weight according to the GPTQ algorithm

    :param module: module with weight being quantized
    :param quant_args: quantization arguments used to find quantization parameters
    :param hessian: preaccumulated hessian for quantization
    :param blocksize: chunk size of quantization updates
    :param percdamp: dampening factor on hessian diagonal
    :return: loss and a dict of quantized parameters (weight, weight_scale,
        weight_zero_point, and optionally weight_global_scale and weight_g_idx)
    """
    strategy = quant_args.strategy
    actorder = quant_args.actorder
    final_shape = module.weight.shape
    final_dtype = module.weight.dtype
    W = module.weight.clone()
    H = hessian

    observer = module.weight_observer

    W = W.to(dtype=GPTQ_PRECISION)
    num_rows = W.shape[0]
    num_columns = W.shape[1]

    if actorder == ActivationOrdering.GROUP and strategy not in (
        QuantizationStrategy.GROUP,
        QuantizationStrategy.TENSOR_GROUP,
    ):
        logger.warning(
            "ActivationOrdering.GROUP requires a grouped quantization strategy; "
            "falling back to actorder=None for this module."
        )
        actorder = None

    # handle activation ordering
    if actorder:
        W, H, perm = _apply_activation_ordering(W, H)

    # handle g_idx and activation ordering
    if actorder == ActivationOrdering.GROUP:
        # actually need scale/zp for permuted weight for this format
        observer(W)
        # use identity g_idx (invert permutation later)

    # handle g_idx
    if strategy in (
        QuantizationStrategy.GROUP,
        QuantizationStrategy.TENSOR_GROUP,
        QuantizationStrategy.BLOCK,
    ):
        # mapping from column index to group index
        divisor = (
            quant_args.group_size
            if strategy != QuantizationStrategy.BLOCK
            else quant_args.block_structure[1]
        )
        g_idx = torch.arange(num_columns, device=W.device, dtype=torch.int) // divisor

        if actorder == ActivationOrdering.WEIGHT:
            g_idx = g_idx[perm]

    qparams = observer.get_qparams()
    scale, zero_point, global_scale = (
        qparams["scale"],
        qparams["zero_point"],
        qparams["global_scale"],
    )

    # sparsity mask
    sparsity = tensor_sparsity(W)
    preserve_zeros = sparsity >= SPARSITY_THRESHOLD
    W_nz_mask = (
        (~torch.isclose(W, torch.zeros(1, device=W.device).float())).float()
        if preserve_zeros
        else None
    )

    losses = torch.zeros(num_rows, device=module.weight.device)

    # mask dead hessian values
    dead = torch.diag(H) == 0
    H[dead, dead] = 1
    W[:, dead] = 0

    # compute inverse hessian in place to save memory
    try:
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(H.shape[0], device=H.device)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H
    except torch.linalg.LinAlgError:
        logger.warning(
            "Failed to invert hessian due to numerical instability. Consider "
            "increasing GPTQModifier.dampening_frac, increasing the number "
            "of calibration samples, or shuffling the calibration dataset. "
            "Falling back to round-to-nearest for this module."
        )
        Hinv = H = torch.eye(num_columns, dtype=H.dtype, device=H.device)

    # See section 3.4 of https://arxiv.org/abs/2203.07259
    for i1 in range(0, num_columns, blocksize):
        i2 = min(i1 + blocksize, num_columns)
        count = i2 - i1

        W1 = W[:, i1:i2].clone()
        Q1 = torch.zeros_like(W1)
        Err1 = torch.zeros_like(W1)
        losses1 = torch.zeros_like(W1)
        Hinv1 = Hinv[i1:i2, i1:i2]

        if preserve_zeros:
            W1_nz_mask = W_nz_mask[:, i1:i2]

        for i in range(count):
            w = W1[:, i]
            d = Hinv1[i, i]
            q = w.clone()

            # quantize column
            if strategy == QuantizationStrategy.TENSOR:
                q = fake_quantize(
                    q, scale, zero_point, quant_args, global_scale=global_scale
                )
            elif strategy == QuantizationStrategy.CHANNEL:
                q = fake_quantize(
                    q,
                    scale[:, 0],
                    zero_point[:, 0],
                    quant_args,
                    global_scale=global_scale,
                )
            # apply global scale to scale quant scale
            elif strategy in (
                QuantizationStrategy.GROUP,
                QuantizationStrategy.TENSOR_GROUP,
            ):
                # get the group index for the current column
                column_idx = i1 + i
                group_index = g_idx[column_idx]

                # Since we're only applying quantization to a slice, this
                # ends up being a channelwise application
                altered_qargs = copy(quant_args)
                altered_qargs.strategy = QuantizationStrategy.CHANNEL

                q = fake_quantize(
                    q,
                    scale[:, group_index],
                    zero_point[:, group_index],
                    altered_qargs,
                    global_scale=global_scale,
                )
            elif strategy == QuantizationStrategy.BLOCK:
                column_idx = i1 + i
                block_column_idx = g_idx[column_idx]
                q = fake_quantize(
                    q.unsqueeze(1),
                    scale[:, block_column_idx : block_column_idx + 1],
                    zero_point[:, block_column_idx : block_column_idx + 1],
                    quant_args,
                    global_scale=global_scale,
                ).squeeze(1)
            else:
                raise ValueError(
                    f"Quantization strategy is not supported for GPTQ: {strategy}"
                )

            # propagate column error
            Q1[:, i] = q
            losses1[:, i] = (w - q) ** 2 / d**2

            err1 = (w - q) / d
            w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
            if preserve_zeros:
                W1[:, i:] -= w1_err * W1_nz_mask[:, i:]
            else:
                W1[:, i:] -= w1_err
            Err1[:, i] = err1

        # propagate block error
        W[:, i1:i2] = Q1
        losses += torch.sum(losses1, 1) / 2

        w_err = Err1.matmul(Hinv[i1:i2, i2:])
        if preserve_zeros:
            W[:, i2:] -= w_err * W_nz_mask[:, i2:]
        else:
            W[:, i2:] -= w_err

    if actorder:
        # restore original permutation
        invperm = torch.argsort(perm)
        W = W[:, invperm]

    W = W.reshape(final_shape).to(final_dtype)

    loss = torch.sum(losses).item()
    q_param_dict = {
        "weight": W,
        "weight_scale": scale.to(dtype=final_dtype),
        "weight_zero_point": zero_point.to(dtype=quant_args.zp_dtype),
    }
    if global_scale:
        q_param_dict["weight_global_scale"] = global_scale.to(dtype=final_dtype)
    if actorder == ActivationOrdering.GROUP:
        q_param_dict["weight_g_idx"] = g_idx[invperm]
    return (loss, q_param_dict)
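
The column-by-column error propagation in the inner loop above can be reduced to a single weight row for illustration. The sketch below is a simplification under stated assumptions: it uses round-to-nearest on a uniform grid in place of fake_quantize, takes the upper Cholesky factor of the inverse Hessian as a given input, and omits blocking, sparsity masks and loss tracking. quantize_row is a hypothetical name.

```python
def quantize_row(weights, hinv_upper, scale):
    """Quantize one weight row column by column, pushing each column's
    rounding error onto the columns that have not been quantized yet."""
    w = list(weights)
    quantized = []
    for i in range(len(w)):
        q = round(w[i] / scale) * scale  # round-to-nearest on a uniform grid
        d = hinv_upper[i][i]
        err = (w[i] - q) / d
        # later columns absorb the error, weighted by row i of the factor
        for j in range(i + 1, len(w)):
            w[j] -= err * hinv_upper[i][j]
        quantized.append(q)
    return quantized


weights = [0.4, 0.4]

# identity factor: no error propagation, i.e. plain round-to-nearest
identity = [[1.0, 0.0], [0.0, 1.0]]
rtn = quantize_row(weights, identity, scale=1.0)

# upper factor U with U^T U = [[1, -1], [-1, 2]], the inverse of H = [[2, 1], [1, 1]]
factor = [[1.0, -1.0], [0.0, 1.0]]
compensated = quantize_row(weights, factor, scale=1.0)
print(rtn, compensated)
```

With the identity factor both weights round to 0, which matches the round-to-nearest fallback in the listing. With correlated inputs, the error from the first column shifts the second weight to 0.8, so it rounds to 1 and partly compensates for the first column's error.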