Normalize attention masks for AutoRound custom datasets.
AutoRound expects every calibration mask row to contain at least one masked
position. When a mask is fully dense (every token is marked valid), its final
position is set to 0; 1-D and 2-D masks keep their original dtype and shape.
More details can be found here: https://github.com/intel/auto-round/blob/50ee58c9e176e9da2a744dbe6ed220f26e80eccd/auto_round/calibration/llm.py#L315-L355
Source code in src/llmcompressor/modifiers/autoround/utils.py
| def fix_attention_mask(
mask: torch.Tensor | list[int] | list[list[int]],
) -> torch.Tensor:
"""
Normalize attention masks for AutoRound custom datasets.
AutoRound expects at least one masked position when the calibration mask is fully
dense. When every token is marked valid, set the final position to 0 while
preserving the original dtype and shape.
More details can be found here: https://github.com/intel/auto-round/blob/50ee58c9e176e9da2a744dbe6ed220f26e80eccd/auto_round/calibration/llm.py#L315-L355
"""
normalized_mask = torch.as_tensor(mask).clone()
if normalized_mask.shape[-1] == 0:
return normalized_mask
if (
normalized_mask.ndim == 4
and normalized_mask.shape[1] == 1
and normalized_mask.shape[2] == 1
):
normalized_mask = normalized_mask.squeeze(2).squeeze(1)
if normalized_mask.ndim in (3, 4):
normalized_mask = _collapse_causal_attention_mask(normalized_mask)
if normalized_mask.ndim == 1:
if torch.all(normalized_mask == 1):
normalized_mask[-1] = 0
return normalized_mask
if normalized_mask.ndim == 2:
all_ones_rows = torch.all(normalized_mask == 1, dim=1)
if torch.any(all_ones_rows):
normalized_mask[all_ones_rows, -1] = 0
return normalized_mask
raise ValueError(
"Unsupported attention mask shape for AutoRound: "
f"{tuple(normalized_mask.shape)}"
)
|