Skip to content

llmcompressor.datasets.utils

Dataset utility functions for LLM compression workflows.

Provides helper functions for loading, processing, and formatting datasets used in model compression pipelines. Handles dataset splitting, tokenization, calibration data preparation, and dataloader creation for both training and one-shot calibration workflows.

Classes:

  • LengthAwareSampler

    Sample data in order of descending sequence length. Relies on an input_ids or decoder_input_ids column existing in the dataset.

Functions:

LengthAwareSampler

LengthAwareSampler(
    data_source: Dataset,
    num_samples: Optional[int] = None,
    batch_size: int = 1,
)

Bases: Sampler[int]

Sample data in order of descending sequence length. Relies on an input_ids or decoder_input_ids column existing in the dataset.

Parameters:

  • data_source (Dataset) –

    dataset containing an input_ids or decoder_input_ids column

  • num_samples (Optional[int], default: None ) –

    Maximum number of samples to sample. The shortest sequences are truncated first.

Source code in src/llmcompressor/datasets/utils.py
def __init__(
    self,
    data_source: Dataset,
    num_samples: Optional[int] = None,
    batch_size: int = 1,
) -> None:
    """
    Precompute the order in which samples of ``data_source`` are visited.

    Indices are sorted by descending length of the ``input_ids`` column, or
    of ``decoder_input_ids`` when ``input_ids`` is absent. If neither column
    exists, a warning is logged and the dataset's natural order is kept.

    :param data_source: dataset exposing ``column_names`` and, ideally, an
        ``input_ids`` or ``decoder_input_ids`` column
    :param num_samples: maximum number of samples; a falsy value (``None``
        or ``0``) falls back to ``len(data_source)``
    :param batch_size: stored on the sampler as ``self.batch_size``
    """
    self.data_source = data_source
    self._num_samples = num_samples or len(data_source)
    self.batch_size = batch_size

    # Prefer "input_ids"; fall back to "decoder_input_ids".
    columns = data_source.column_names
    feature_name = next(
        (name for name in ("input_ids", "decoder_input_ids") if name in columns),
        None,
    )

    if feature_name is None:
        # No token column to measure, so keep the dataset's natural order.
        logger.warning(f"Could not find input ids in {data_source.column_names}")
        self.order = range(len(data_source))
        return

    lengths = list(map(len, data_source[feature_name]))
    descending = torch.argsort(torch.tensor(lengths), descending=True)
    self.order = descending.tolist()
    # Helper defined elsewhere on the class; presumably reports batch
    # statistics derived from the lengths -- TODO confirm.
    self._calculate_and_log_batch_stats(lengths)

get_calibration_dataloader

get_calibration_dataloader(
    dataset_args: DatasetArguments, processor: Processor
) -> DataLoader | None

Get the dataloader used for oneshot calibration.

If dataset_args.dataset is already a PyTorch DataLoader, it is returned directly, bypassing dataset loading and tokenization.

Parameters:

  • dataset_args (DatasetArguments) –

    DatasetArguments that contains the dataset parameters.

  • processor (Processor) –

    Processor or the tokenizer of the model.

Returns:

  • DataLoader | None

    PyTorch dataloader object that contains the calibration dataset, or None for data-free flows.

Source code in src/llmcompressor/datasets/utils.py
def get_calibration_dataloader(
    dataset_args: DatasetArguments,
    processor: Processor,
) -> DataLoader | None:
    """
    Build the dataloader used for oneshot calibration.

    A ``dataset_args.dataset`` that is already a PyTorch DataLoader is handed
    back untouched, skipping dataset loading and tokenization. Otherwise the
    dataset is processed via ``get_processed_dataset`` and wrapped by
    ``format_calibration_data``.

    :param dataset_args: DatasetArguments that contains the dataset parameters.
    :param processor: Processor or the tokenizer of the model.
    :return: PyTorch dataloader object that contains the calibration
        dataset, or None for data-free flows.
    """
    source = dataset_args.dataset

    # No dataset: weight-only quantization or dynamic quantization
    if source is None:
        return None

    # A ready-made dataloader needs no further preparation
    if isinstance(source, DataLoader):
        return source

    processed = get_processed_dataset(
        dataset_args=dataset_args,
        processor=processor,
    )
    return (
        None
        if processed is None
        else format_calibration_data(dataset_args, processed, processor)
    )

get_processed_dataset

get_processed_dataset(
    dataset_args: DatasetArguments,
    processor: Processor | None = None,
) -> Dataset | None

Loads dataset based on dataset_args.

Parameters:

  • dataset_args (DatasetArguments) –

    DatasetArguments that contain dataset loading and processing params

  • processor (Processor | None, default: None ) –

    processor or tokenizer to use for dataset tokenization

Returns:

  • Dataset | None

    A Dataset corresponding to the single split for calibration

Source code in src/llmcompressor/datasets/utils.py
def get_processed_dataset(
    dataset_args: DatasetArguments,
    processor: Processor | None = None,
) -> Dataset | None:
    """
    Load (and, if required, tokenize) the calibration dataset.

    ``dataset_args.splits`` is reduced to a single split string first: a
    string is used as-is, a dict must carry a ``"calibration"`` key
    (deprecated), and a list contributes its first element (deprecated).
    A dataset that already exposes an ``input_ids`` column is returned
    unchanged; anything else is loaded through the ``TextGenerationDataset``
    registry.

    :param dataset_args: DatasetArguments that contain dataset loading and
        processing params
    :param processor: processor or tokenizer to use for dataset tokenization
    :return: A Dataset corresponding to the single split for calibration, or
        None when no dataset was provided
    :raises ValueError: if ``splits`` has an unsupported type, is a dict
        without a ``"calibration"`` key, or if no split was resolved and the
        loaded object has no ``"train"`` split
    """
    if dataset_args.dataset is None:
        logger.warning(
            "Running oneshot without calibration data. This is expected for "
            "weight-only and dynamic quantization"
        )
        return None

    # --- resolve the requested split into a single split string ---
    splits = dataset_args.splits

    if splits is None:
        split_str = None
    elif isinstance(splits, str):
        split_str = splits
    elif isinstance(splits, dict):
        if "calibration" not in splits:
            raise ValueError(
                "Passing `splits` as a dict is only supported when it contains a "
                "`'calibration'` key during the deprecation period. "
                "Please pass a split string instead."
            )

        split_str = splits["calibration"]
        if len(splits) > 1:
            ignored_keys = set(splits.keys()) - {"calibration"}
            logger.warning(
                f"Ignoring extra keys in splits: {list(ignored_keys)}. "
                "Only the 'calibration' split is used."
            )

        logger.warning(
            "Passing `splits` as a dictionary is deprecated. "
            f"Extracted split string: '{split_str}'. "
            "Please pass `splits` as a string instead."
        )
    elif isinstance(splits, list):
        split_str = splits[0] if len(splits) > 0 else None
        logger.warning(
            "Passing `splits` as a list is deprecated. "
            f"Using first element: '{split_str}'. "
            "Please pass `splits` as a string instead."
        )
    else:
        raise ValueError(
            f"Invalid splits type: {type(splits)}. Expected a split string "
            "or the deprecated `{'calibration': ...}` form."
        )

    # --- already-tokenized datasets pass straight through ---
    provided = dataset_args.dataset
    if hasattr(provided, "column_names") and "input_ids" in provided.column_names:
        return provided

    # --- otherwise load and tokenize through the registry ---
    # anything that is not a registry name string is handled as "custom"
    registry_key = provided if isinstance(provided, str) else "custom"
    manager = TextGenerationDataset.load_from_registry(
        registry_key,
        dataset_args=dataset_args,
        split=split_str,
        processor=processor,
    )
    loaded = manager()

    if isinstance(loaded, Dataset):
        return loaded

    # If no split was specified, a DatasetDict format is typically returned.
    # Fall back to the 'train' split for backward compatibility.
    if "train" not in loaded:
        raise ValueError(
            "No split specified and 'train' split not found in dataset. "
            "Please specify `splits` explicitly."
        )

    logger.warning(
        "No split was specified, but a multi-split dataset was loaded. "
        "Falling back to the 'train' split for calibration."
    )
    return loaded["train"]

get_rank_partition

get_rank_partition(split: str, num_samples: int) -> str

Utility for splitting data across ranks in a distributed setting; it also works in a non-distributed setting.

Parameters:

  • split (str) –

    the split string to partition, e.g. "train"

  • num_samples (int) –

    the total number of samples in the dataset to partition

Returns:

  • str

    a partitioned split string

    Usage example:

    DATASET_ID = "HuggingFaceH4/ultrachat_200k"
    DATASET_SPLIT = "train_sft"
    NUM_CALIBRATION_SAMPLES = 256

    split = get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
    ds = load_dataset(
        DATASET_ID,
        split=split,
    )

    For S samples and D devices, when S is not perfectly divisible by D, we give each device at least S//D samples and distribute the remaining samples as evenly as possible across all devices.

Source code in src/llmcompressor/datasets/utils.py
def get_rank_partition(split: str, num_samples: int) -> str:
    """
    Utility for splitting data in a distributed setting; it also works in a
    non-distributed setting, where the rank is taken as 0 and the world
    size as 1.

    :param split: the split string to partition, e.g. "train"
    :param num_samples: the total number of samples in the dataset to partition
    :return: a partitioned split string of the form "{split}[{start}:{end}]"
    :raises AssertionError: if ``split`` already contains a "["

    Usage example:

    DATASET_ID = "HuggingFaceH4/ultrachat_200k"
    DATASET_SPLIT = "train_sft"
    NUM_CALIBRATION_SAMPLES = 256

    split = get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
    ds = load_dataset(
        DATASET_ID,
        split=split,
    )

    for S samples and D devices, when S is not perfectly divisible by D,
    we give each device at least S//D samples and distribute
    the remaining samples as evenly as possible across all devices
    """
    # Raised explicitly instead of via `assert` so the check survives
    # `python -O`; AssertionError is kept so callers observe the same
    # exception type as before.
    if "[" in split:
        raise AssertionError(
            "Split string should not already contain partitioning brackets"
        )

    # Query the process group once; fall back to a single-process layout.
    if dist.is_initialized():
        rank, world_size = dist.get_rank(), dist.get_world_size()
    else:
        rank, world_size = 0, 1

    start, end = _get_partition_start_end(num_samples, rank, world_size)
    return f"{split}[{start}:{end}]"