Skip to content

speculators.data_generation.preprocessing

Functions:

build_eagle3_dataset

build_eagle3_dataset(
    dataset: Dataset,
    processor: ProcessorLike,
    max_length: int = 2048,
    num_proc: int = 8,
    assistant_pattern: str | Pattern[str] | None = None,
    turn_dropout: bool = False,
    minimum_valid_tokens: int | None = None,
) -> HFDataset

Build EAGLE3 dataset by tokenizing conversations and creating loss masks.

Uses the processor's built-in chat template via apply_chat_template.

Args: dataset: Raw dataset with conversations. processor: Processor with chat template support. max_length: Maximum sequence length. num_proc: Number of processes for parallel processing. assistant_pattern: Optional custom regex pattern for matching assistant responses. If None, the HF assistant token mask is used when the processor supports it; otherwise the pattern is auto-detected from the chat template. turn_dropout: If True, randomly keeps first N consecutive turns per conversation. minimum_valid_tokens: Number of tokens to consider for a valid sample.

Source code in speculators/data_generation/preprocessing.py
def build_eagle3_dataset(
    dataset: HFDataset,
    processor: ProcessorLike,
    max_length: int = 2048,
    num_proc: int = 8,
    assistant_pattern: str | Pattern[str] | None = None,
    turn_dropout: bool = False,
    minimum_valid_tokens: int | None = None,
) -> HFDataset:
    """Tokenize conversations and attach loss masks for EAGLE3 training.

    Every batch is handed to ``_preprocess_batch`` through ``dataset.map``;
    the columns of the input dataset are dropped from the result, which is
    then switched to the ``torch`` format.

    Args:
        dataset: Raw dataset with conversations.
        processor: Processor with chat template support.
        max_length: Maximum sequence length.
        num_proc: Number of processes used by ``dataset.map``.
        assistant_pattern: Optional custom regex pattern for matching
            assistant responses. When None, the HF assistant token mask is
            used if ``_supports_assistant_mask(processor)`` is true;
            otherwise a pattern is obtained from
            ``_detect_assistant_pattern(processor)``.
        turn_dropout: If True, randomly keeps first N consecutive turns per
            conversation.
        minimum_valid_tokens: Number of tokens to consider for a valid
            sample.

    Returns:
        The mapped dataset with its format set to ``torch``.
    """
    # Decide how assistant spans are located for the loss mask.
    if assistant_pattern is None:
        if _supports_assistant_mask(processor):
            # Leaving the pattern as None tells _preprocess_batch to rely on
            # the HF assistant mask.
            log.info("Using HF assistant token mask for loss masking")
        else:
            assistant_pattern = _detect_assistant_pattern(processor)
            log.info(
                f"Detected assistant pattern: {str(assistant_pattern)[:80]}..."
            )
    else:
        log.info(f"Using custom assistant pattern: {str(assistant_pattern)[:80]}...")

    def _run_batch(examples):
        return _preprocess_batch(
            examples,
            processor,
            max_length,
            assistant_pattern,
            turn_dropout,
            minimum_valid_tokens,
        )

    columns_to_drop = dataset.column_names

    # Avoid CPU contention for MM processing:
    # https://github.com/vllm-project/vllm/pull/31879
    if isinstance(processor, ProcessorMixin):
        thread_guard = set_default_torch_num_threads()
    else:
        thread_guard = nullcontext()

    with thread_guard:
        tokenized = dataset.map(
            _run_batch,
            batched=True,
            num_proc=num_proc,
            batch_size=1000,
            remove_columns=columns_to_drop,
            keep_in_memory=True,  # skip caching
        )

    tokenized.set_format(type="torch")
    return tokenized

load_and_preprocess_dataset

load_and_preprocess_dataset(
    target_model_path: str,
    train_data_paths: list[str],
    *,
    seq_length: int,
    build_dataset_num_proc: int = 8,
    seed: int = 0,
    max_samples: int | None = None,
    token_freq_path: Path | str = "./token_freq.pt",
    assistant_pattern: str | None = None,
    turn_dropout: bool = False,
    minimum_valid_tokens: int | None = None,
    trust_remote_code: bool = False,
) -> tuple[HFDataset, ProcessorLike]

Load, tokenize, and preprocess a dataset for EAGLE3 training.

Uses the processor's built-in chat template via apply_chat_template. Caching is handled automatically by HuggingFace datasets.

Args: target_model_path: HuggingFace model ID or local path. train_data_paths: List of dataset names or paths to JSON/JSONL files. seq_length: Maximum sequence length. build_dataset_num_proc: Number of processes for dataset building. seed: Random seed for shuffling. max_samples: Optional limit on number of samples. token_freq_path: Path to save token frequency distribution. assistant_pattern: Optional custom regex pattern for matching assistant responses. If None, pattern will be auto-detected from chat template. turn_dropout: If True, randomly keeps first N consecutive turns per conversation. minimum_valid_tokens: Number of tokens to consider for a valid sample. trust_remote_code: If True, allows executing code from HF Hub.

Returns: Tuple of (preprocessed_dataset, processor)

Source code in speculators/data_generation/preprocessing.py
def load_and_preprocess_dataset(
    target_model_path: str,
    train_data_paths: list[str],
    *,
    seq_length: int,
    build_dataset_num_proc: int = 8,
    seed: int = 0,
    max_samples: int | None = None,
    token_freq_path: Path | str = "./token_freq.pt",  # noqa: S107
    assistant_pattern: str | None = None,
    turn_dropout: bool = False,
    minimum_valid_tokens: int | None = None,
    trust_remote_code: bool = False,
) -> tuple[HFDataset, ProcessorLike]:
    """Load, tokenize, and preprocess one or more datasets for EAGLE3 training.

    Each entry of ``train_data_paths`` is loaded, shuffled, optionally trimmed
    and normalized, then tokenized with ``build_eagle3_dataset``. The
    per-source results are concatenated, shuffled again, and cut down to
    ``max_samples``. The token frequency distribution of the final dataset is
    saved and one sample is visualized.

    Args:
        target_model_path: HuggingFace model ID or local path
        train_data_paths: Dataset names or paths to JSON/JSONL files; every
                          entry is processed and the results are combined
        seq_length: Maximum sequence length
        build_dataset_num_proc: Number of processes for dataset building
        seed: Random seed used for both the per-dataset shuffle and the
              shuffle of the combined dataset
        max_samples: Optional limit on number of samples in the combined
                     dataset. Each source is first trimmed to at most
                     3 * max_samples rows to bound preprocessing work.
        token_freq_path: Path to save token frequency distribution
        assistant_pattern: Optional custom regex pattern for matching assistant
                          responses, forwarded to build_eagle3_dataset
        turn_dropout: If True, randomly keeps first N consecutive turns per
                     conversation
        minimum_valid_tokens: Number of tokens to consider for a valid sample
        trust_remote_code: If True, allows executing code from HF Hub.

    Returns:
        Tuple of (preprocessed_dataset, processor)

    Raises:
        ValueError: If minimum_valid_tokens is negative, if train_data_paths
            is empty, or if the loaded processor has no chat template.
    """
    if minimum_valid_tokens is not None and minimum_valid_tokens < 0:
        raise ValueError("minimum_valid_tokens must be >= 0")
    if not train_data_paths:
        # Fail before loading the processor; concatenating zero datasets
        # cannot produce a result anyway.
        raise ValueError("train_data_paths must contain at least one dataset")
    log.section("Starting dataset preprocessing")
    if minimum_valid_tokens is not None:
        log.info(
            f"Filtering samples with fewer than {minimum_valid_tokens} valid tokens"
        )

    log.subsection("Loading processor")
    processor = load_processor(target_model_path, trust_remote_code=trust_remote_code)

    if not hasattr(processor, "apply_chat_template") or processor.chat_template is None:
        raise ValueError(
            f"Processor for {target_model_path} does not support chat templates. "
            "Please use a model with a pre-configured chat template."
        )

    processed_datasets = []
    for train_data_path in train_data_paths:
        log.subsection(f"Processing {train_data_path}")
        raw_dataset, normalize_fn = load_raw_dataset(train_data_path)
        raw_dataset = raw_dataset.shuffle(seed=seed)

        if max_samples is not None and len(raw_dataset) > 3 * max_samples:
            # Reduce size to 3 * max_samples to reduce processing
            # This will then be reduced further to max_samples
            # after combining datasets and shuffling
            raw_dataset = raw_dataset.select(range(3 * max_samples))

        if normalize_fn is not None:
            raw_dataset = raw_dataset.map(
                normalize_fn,
                num_proc=build_dataset_num_proc,
                keep_in_memory=True,  # skip caching
            )

        log.info(f"Loaded {len(raw_dataset)} samples")

        if turn_dropout:
            log.info("Turn dropout enabled: randomly keeping N consecutive turns")

        preprocessed_dataset = build_eagle3_dataset(
            dataset=raw_dataset,
            processor=processor,
            max_length=seq_length,
            num_proc=build_dataset_num_proc,
            assistant_pattern=assistant_pattern,
            turn_dropout=turn_dropout,
            minimum_valid_tokens=minimum_valid_tokens,
        )
        if minimum_valid_tokens is not None:
            log.info(f"Kept {len(preprocessed_dataset)} samples after filtering")
        processed_datasets.append(preprocessed_dataset)

    combined_dataset = concatenate_datasets(processed_datasets)
    # Dataset.shuffle returns a new dataset rather than shuffling in place, so
    # the result must be rebound. Otherwise the max_samples cut below would
    # take rows in concatenation order and favour the first datasets.
    combined_dataset = combined_dataset.shuffle(seed=seed)
    if max_samples is not None and len(combined_dataset) > max_samples:
        combined_dataset = combined_dataset.select(range(max_samples))

    log.subsection("Computing token frequency distribution")
    save_token_frequency_distribution(
        dataset=combined_dataset,
        output_path=token_freq_path,
    )

    log.subsection("Visualizing sample")
    _visualize_sample(combined_dataset, processor, idx=0)

    log.section("Dataset preprocessing complete")

    return combined_dataset, processor

load_raw_dataset

load_raw_dataset(
    train_data_path: str,
) -> tuple[HFDataset, Callable[[dict], dict] | None]

Load raw dataset from local file or HuggingFace.

Source code in speculators/data_generation/preprocessing.py
def load_raw_dataset(
    train_data_path: str,
) -> tuple[HFDataset, Callable[[dict], dict] | None]:
    """Load a raw dataset from a local JSON/JSONL file or a registered config.

    Args:
        train_data_path: Either a path ending in ``.json``/``.jsonl`` or a key
            of ``DATASET_CONFIGS``.

    Returns:
        Tuple of (dataset, normalize_fn). For local files normalize_fn is
        None; for registered datasets it is the config's ``normalize_fn``.

    Raises:
        ValueError: If the path is not a JSON/JSONL file and is not a key of
            ``DATASET_CONFIGS``.
    """
    is_local_json = train_data_path.endswith((".jsonl", ".json"))
    if is_local_json:
        local_dataset = load_dataset(
            "json", data_files=train_data_path, split="train"
        )
        return local_dataset, None

    if train_data_path not in DATASET_CONFIGS:
        supported = list(DATASET_CONFIGS.keys())
        message = (
            f"Unsupported dataset: {train_data_path}. "
            f"Supported: local .json/.jsonl files or {supported}"
        )
        raise ValueError(message)

    config = DATASET_CONFIGS[train_data_path]
    hub_dataset = load_dataset(
        config.hf_path,
        name=config.subset,
        split=config.split,
    )

    # Apply the config's row filter, if any, before handing the data back.
    row_filter = config.filter_fn
    if row_filter is not None:
        hub_dataset = hub_dataset.filter(row_filter)

    return hub_dataset, config.normalize_fn