Dataset utility functions for LLM compression workflows.
Provides helper functions for loading, processing, and formatting datasets used
in model compression pipelines. Handles dataset splitting, tokenization,
calibration data preparation, and dataloader creation for both training and
one-shot calibration workflows.
Classes:
-
LengthAwareSampler
–
Sample data in order of descending sequence length. Relies on an input_ids or decoder_input_ids column existing in the dataset.
Functions:
LengthAwareSampler
LengthAwareSampler(
data_source: Dataset,
num_samples: Optional[int] = None,
batch_size: int = 1,
)
Bases: Sampler[int]
Sample data in order of descending sequence length. Relies on an input_ids or
decoder_input_ids column existing in the dataset.
Parameters:
-
data_source
(Dataset)
–
dataset containing an input_ids or decoder_input_ids column
-
num_samples
(Optional[int], default:
None
)
–
Maximum number of samples to draw. Samples with the shortest
sequence lengths are truncated first.
Source code in src/llmcompressor/datasets/utils.py
def __init__(
    self,
    data_source: Dataset,
    num_samples: Optional[int] = None,
    batch_size: int = 1,
) -> None:
    """
    Prepare a longest-first sampling order over ``data_source``.

    :param data_source: dataset expected to expose an ``input_ids`` or
        ``decoder_input_ids`` column; ``input_ids`` wins when both exist
    :param num_samples: maximum number of samples to draw; any falsy value
        (``None`` or ``0``) falls back to ``len(data_source)``
    :param batch_size: batch size stored on the sampler (presumably consumed
        by the batch-statistics helper — TODO confirm)
    """
    self.data_source = data_source
    # NOTE(review): `or` also maps num_samples=0 to the full dataset length;
    # confirm that is intended rather than "sample nothing".
    self._num_samples = num_samples or len(data_source)
    self.batch_size = batch_size

    columns = data_source.column_names
    length_column = next(
        (name for name in ("input_ids", "decoder_input_ids") if name in columns),
        None,
    )

    if length_column is None:
        # Nothing to sort by: keep the dataset's natural order and skip the
        # batch statistics entirely.
        logger.warning(f"Could not find input ids in {columns}")
        self.order = range(len(data_source))
        return

    lengths = [len(ids) for ids in data_source[length_column]]
    # Sample indices ordered from the longest sequence to the shortest.
    self.order = torch.argsort(torch.tensor(lengths), descending=True).tolist()
    self._calculate_and_log_batch_stats(lengths)
|
get_calibration_dataloader
get_calibration_dataloader(
dataset_args: DatasetArguments, processor: Processor
) -> DataLoader | None
Get the dataloader used for oneshot calibration.
If dataset_args.dataset is already a PyTorch DataLoader,
it is returned directly, bypassing dataset loading and tokenization.
Parameters:
-
dataset_args
(DatasetArguments)
–
DatasetArguments that contains the dataset parameters.
-
processor
(Processor)
–
Processor or the tokenizer of the model.
Returns:
-
DataLoader | None
–
PyTorch dataloader object that contains the calibration
dataset, or None for data-free flows.
Source code in src/llmcompressor/datasets/utils.py
def get_calibration_dataloader(
    dataset_args: DatasetArguments,
    processor: Processor,
) -> DataLoader | None:
    """
    Build the dataloader consumed by oneshot calibration.

    Resolution order:

    1. No dataset configured -> ``None`` (data-free flows such as
       weight-only or dynamic quantization).
    2. ``dataset_args.dataset`` is already a PyTorch ``DataLoader`` -> it is
       returned untouched, skipping dataset loading and tokenization.
    3. Otherwise the dataset is loaded via ``get_processed_dataset`` and the
       result is passed to ``format_calibration_data``; ``None`` is returned
       if no dataset comes back.

    :param dataset_args: DatasetArguments that contains the dataset parameters.
    :param processor: Processor or the tokenizer of the model.
    :return: PyTorch dataloader holding the calibration dataset, or ``None``
        for data-free flows.
    """
    provided = dataset_args.dataset

    if provided is None:
        # weight-only quantization or dynamic quantization
        return None

    if isinstance(provided, DataLoader):
        return provided

    processed = get_processed_dataset(
        dataset_args=dataset_args,
        processor=processor,
    )
    if processed is not None:
        return format_calibration_data(dataset_args, processed, processor)
    return None
|
get_processed_dataset
get_processed_dataset(
dataset_args: DatasetArguments,
processor: Processor | None = None,
) -> Dataset | None
Loads dataset based on dataset_args.
Parameters:
-
dataset_args
(DatasetArguments)
–
DatasetArguments that contain dataset loading and
processing params
-
processor
(Processor | None, default:
None
)
–
processor or tokenizer to use for dataset tokenization
Returns:
-
Dataset | None
–
A Dataset corresponding to the single split for calibration
Source code in src/llmcompressor/datasets/utils.py
def _resolve_split_str(splits) -> str | None:
    """
    Normalize the user-supplied ``splits`` value to a single split string.

    Accepted inputs:

    * ``None`` -> ``None``
    * a split string -> returned unchanged
    * (deprecated) a dict containing a ``"calibration"`` key -> that value;
      any other keys are ignored with a warning
    * (deprecated) a list -> its first element, or ``None`` when empty

    Both deprecated forms log a deprecation warning.

    :param splits: the raw ``dataset_args.splits`` value
    :return: the split string to load, or ``None``
    :raises ValueError: for a dict without ``"calibration"`` or for any
        other unsupported type
    """
    if splits is None:
        return None

    if isinstance(splits, str):
        return splits

    if isinstance(splits, dict):
        if "calibration" not in splits:
            raise ValueError(
                "Passing `splits` as a dict is only supported when it contains a "
                "`'calibration'` key during the deprecation period. "
                "Please pass a split string instead."
            )
        chosen = splits["calibration"]
        if len(splits) > 1:
            extra_keys = set(splits.keys()) - {"calibration"}
            logger.warning(
                f"Ignoring extra keys in splits: {list(extra_keys)}. "
                "Only the 'calibration' split is used."
            )
        logger.warning(
            "Passing `splits` as a dictionary is deprecated. "
            f"Extracted split string: '{chosen}'. "
            "Please pass `splits` as a string instead."
        )
        return chosen

    if isinstance(splits, list):
        chosen = splits[0] if splits else None
        logger.warning(
            "Passing `splits` as a list is deprecated. "
            f"Using first element: '{chosen}'. "
            "Please pass `splits` as a string instead."
        )
        return chosen

    raise ValueError(
        f"Invalid splits type: {type(splits)}. Expected a split string "
        "or the deprecated `{'calibration': ...}` form."
    )


def get_processed_dataset(
    dataset_args: DatasetArguments,
    processor: Processor | None = None,
) -> Dataset | None:
    """
    Load the single calibration split described by ``dataset_args``.

    Steps:

    1. With no dataset configured, log a warning and return ``None``.
    2. Normalize ``dataset_args.splits`` to one split string (deprecated dict
       and list forms are accepted with a warning).
    3. A dataset that already has an ``input_ids`` column is returned as-is.
    4. Otherwise the dataset is loaded and tokenized through
       ``TextGenerationDataset.load_from_registry``. A string dataset is
       used as the registry id; anything else uses the ``"custom"`` entry.
    5. If the loader returns something other than a ``Dataset`` (typically
       a multi-split result when no split was given), fall back to its
       ``"train"`` split.

    :param dataset_args: DatasetArguments that contain dataset loading and
        processing params
    :param processor: processor or tokenizer to use for dataset tokenization
    :return: a Dataset for the calibration split, or ``None`` when no
        dataset is configured
    :raises ValueError: on unsupported ``splits`` values, or when a
        multi-split result has no ``"train"`` split
    """
    if dataset_args.dataset is None:
        logger.warning(
            "Running oneshot without calibration data. This is expected for "
            "weight-only and dynamic quantization"
        )
        return None

    split_str = _resolve_split_str(dataset_args.splits)

    dataset = dataset_args.dataset
    if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names:
        # dataset is already tokenized
        return dataset

    # default to custom dataset if dataset provided isn't a string
    registry_id = dataset if isinstance(dataset, str) else "custom"
    loaded = TextGenerationDataset.load_from_registry(
        registry_id,
        dataset_args=dataset_args,
        split=split_str,
        processor=processor,
    )()

    if isinstance(loaded, Dataset):
        return loaded

    # Without an explicit split a multi-split result is typically returned;
    # use its 'train' split for backward compatibility.
    if "train" not in loaded:
        raise ValueError(
            "No split specified and 'train' split not found in dataset. "
            "Please specify `splits` explicitly."
        )
    logger.warning(
        "No split was specified, but a multi-split dataset was loaded. "
        "Falling back to the 'train' split for calibration."
    )
    return loaded["train"]
|
get_rank_partition
get_rank_partition(split: str, num_samples: int) -> str
Utility for partitioning a dataset split across ranks in a distributed setting;
it also works in a non-distributed setting.
Parameters:
-
split
(str)
–
the split string to partition, e.g. "train"
-
num_samples
(int)
–
the total number of samples in the dataset to partition
Returns:
-
str
–
a partitioned split string
Usage example:
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 256
split = get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
ds = load_dataset(
DATASET_ID,
split=split,
)
For S samples and D devices, when S is not perfectly divisible by D,
we give each device at least S//D samples and distribute
the remaining samples as evenly as possible across all devices
Source code in src/llmcompressor/datasets/utils.py
def get_rank_partition(split: str, num_samples: int) -> str:
    """
    Restrict ``split`` to the slice of samples owned by the current rank.

    Works with or without ``torch.distributed``. When no process group is
    initialized, the caller is treated as rank 0 in a world of size 1.

    For S samples and D devices, when S is not perfectly divisible by D,
    each device gets at least S // D samples and the remaining samples are
    distributed as evenly as possible across all devices.

    :param split: the split string to partition, e.g. ``"train"``; it must
        not already contain ``[`` partitioning brackets
    :param num_samples: the total number of samples in the dataset to
        partition
    :return: ``split`` with a ``[start:end]`` slice appended

    Usage example::

        DATASET_ID = "HuggingFaceH4/ultrachat_200k"
        DATASET_SPLIT = "train_sft"
        NUM_CALIBRATION_SAMPLES = 256
        split = get_rank_partition(DATASET_SPLIT, NUM_CALIBRATION_SAMPLES)
        ds = load_dataset(
            DATASET_ID,
            split=split,
        )
    """
    assert "[" not in split, (
        "Split string should not already contain partitioning brackets"
    )

    if dist.is_initialized():
        rank, world_size = dist.get_rank(), dist.get_world_size()
    else:
        rank, world_size = 0, 1

    start, end = _get_partition_start_end(num_samples, rank, world_size)
    return f"{split}[{start}:{end}]"
|