Skip to content

llmcompressor.modifiers.autoround.utils

Functions:

fix_attention_mask

fix_attention_mask(
    mask: Tensor | list[int] | list[list[int]],
) -> torch.Tensor

Normalize attention masks for AutoRound custom datasets.

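AutoRound expects every calibration attention mask to contain at least one masked position. When every token of a sequence is marked valid (a fully dense mask), the final position of that sequence is set to 0; for 2-D masks this is done independently for each row. The original dtype is preserved, but masks shaped `(batch, 1, 1, seq_len)` are squeezed to `(batch, seq_len)`, so the returned shape can differ from the input shape. More details can be found here: https://github.com/intel/auto-round/blob/50ee58c9e176e9da2a744dbe6ed220f26e80eccd/auto_round/calibration/llm.py#L315-L355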

Source code in src/llmcompressor/modifiers/autoround/utils.py
def fix_attention_mask(
    mask: torch.Tensor | list[int] | list[list[int]],
) -> torch.Tensor:
    """
    Normalize attention masks for AutoRound custom datasets.

    AutoRound expects at least one masked position when the calibration mask is fully
    dense. When every token is marked valid, set the final position to 0 while
    preserving the original dtype and shape.
    More details can be found here: https://github.com/intel/auto-round/blob/50ee58c9e176e9da2a744dbe6ed220f26e80eccd/auto_round/calibration/llm.py#L315-L355
    """
    normalized_mask = torch.as_tensor(mask).clone()
    if normalized_mask.shape[-1] == 0:
        return normalized_mask

    if (
        normalized_mask.ndim == 4
        and normalized_mask.shape[1] == 1
        and normalized_mask.shape[2] == 1
    ):
        normalized_mask = normalized_mask.squeeze(2).squeeze(1)

    if normalized_mask.ndim in (3, 4):
        normalized_mask = _collapse_causal_attention_mask(normalized_mask)

    if normalized_mask.ndim == 1:
        if torch.all(normalized_mask == 1):
            normalized_mask[-1] = 0
        return normalized_mask

    if normalized_mask.ndim == 2:
        all_ones_rows = torch.all(normalized_mask == 1, dim=1)
        if torch.any(all_ones_rows):
            normalized_mask[all_ones_rows, -1] = 0
        return normalized_mask

    raise ValueError(
        "Unsupported attention mask shape for AutoRound: "
        f"{tuple(normalized_mask.shape)}"
    )