data_generation_offline.py
Generates training data for speculator models by extracting hidden states from a running vLLM server. Connects to a vLLM endpoint via the OpenAI-compatible API and saves output as individual .safetensors files for offline training.
Features
- Automatic resumption: Detects existing .safetensors files in the output directory and skips already-completed samples, so interrupted runs can be resumed without reprocessing.
- Error handling with auto-retries: Failed requests are automatically retried up to --max-retries times. Samples that still fail are skipped by default, allowing the rest of the dataset to complete.
- Consecutive failure detection: Aborts early after --max-consecutive-errors consecutive failures to avoid silently churning through the dataset when the server is unreachable.
- Async concurrency: Sends multiple requests to the vLLM server in parallel, controlled by --concurrency, for high throughput.
- Output validation: Optional --validate-outputs flag verifies that saved hidden states match expected token IDs and sequence lengths.
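The output validation check can be pictured roughly as follows. This is a minimal sketch under stated assumptions, not the script's actual code: the helper name `validate_sample` and its arguments are hypothetical, and it works on plain Python values so that it does not depend on how the tensors are stored inside each file.

```python
from typing import Sequence


def validate_sample(
    prompt_token_ids: Sequence[int],
    saved_token_ids: Sequence[int],
    hidden_states_seq_len: int,
) -> list[str]:
    """Return a list of problems; an empty list means the sample passed.

    Hypothetical helper: the caller is assumed to have already read the
    token IDs and the hidden-states sequence length from the saved file.
    """
    problems = []
    # Check 1: the token IDs stored with the sample must equal the prompt tokens.
    if list(saved_token_ids) != list(prompt_token_ids):
        problems.append("saved token IDs do not match the prompt tokens")
    # Check 2: there must be exactly one hidden-state position per token.
    if hidden_states_seq_len != len(saved_token_ids):
        problems.append(
            f"hidden states cover {hidden_states_seq_len} positions, "
            f"expected {len(saved_token_ids)}"
        )
    return problems
```

Returning a list of problems rather than raising on the first mismatch lets a caller report every issue found in a sample at once.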
Basic Usage
python scripts/data_generation_offline.py \
--preprocessed-data ./preprocessed_dataset \
--output ./training_data \
--max-samples 5000
Arguments
Model Arguments
- --endpoint (str, default: http://localhost:8000/v1): The address of the vLLM instance to use for hidden states generation. The vLLM instance must be configured for hidden states extraction (see launch_vllm.py).
- --model (str, default: None): HuggingFace model ID or local path for the target model. Used for verification only - the model is auto-detected from the vLLM endpoint.
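One way the auto-detection and verification of the model could work is by listing the served models through the OpenAI-compatible `/models` route. The sketch below assumes that approach; the function names `parse_served_model`, `detect_model`, and `check_model` are hypothetical, and the script itself may detect the model differently.

```python
import json
import urllib.request
from typing import Optional


def parse_served_model(models_response: dict) -> str:
    """Extract the first model ID from an OpenAI-style model list payload."""
    data = models_response.get("data", [])
    if not data:
        raise ValueError("endpoint reported no served models")
    return data[0]["id"]


def detect_model(endpoint: str, timeout: float = 10.0) -> str:
    """Query <endpoint>/models and return the first served model ID."""
    url = endpoint.rstrip("/") + "/models"
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        return parse_served_model(json.load(resp))


def check_model(expected: Optional[str], served: str) -> None:
    """Raise if a model was given explicitly and differs from the served one."""
    if expected is not None and expected != served:
        raise ValueError(f"expected model {expected!r}, endpoint serves {served!r}")
```

With a server running, `check_model(None, detect_model("http://localhost:8000/v1"))` would accept whatever model the endpoint reports, while passing an explicit model ID turns the call into a consistency check.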
Data Arguments
- --preprocessed-data (str, required): Path to preprocessed dataset (produced by prepare_data.py).
- --max-samples (int, default: None): Maximum number of samples to process. If None, processes all samples.
Output Arguments
- --output (str, default: None): Directory to save generated .safetensors files. Defaults to <preprocessed-data>/hidden_states.
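The default output location and the resumption scan can be sketched as below. This is an illustration only: the helper names are hypothetical, and the file naming scheme (`data_<index>.safetensors`) is an assumption made for the example, since the actual naming used by the script is not stated here.

```python
from pathlib import Path
from typing import Optional


def resolve_output_dir(preprocessed_data: str, output: Optional[str]) -> Path:
    """Use --output when given, else <preprocessed-data>/hidden_states."""
    if output is not None:
        return Path(output)
    return Path(preprocessed_data) / "hidden_states"


def completed_indices(output_dir: Path) -> set[int]:
    """Collect sample indices that already have a .safetensors file.

    Assumes a hypothetical naming scheme where the file stem ends in the
    sample index, e.g. data_42.safetensors.
    """
    done: set[int] = set()
    if not output_dir.is_dir():
        return done  # nothing written yet, so nothing to skip
    for path in output_dir.glob("*.safetensors"):
        suffix = path.stem.rsplit("_", 1)[-1]
        if suffix.isdigit():
            done.add(int(suffix))
    return done
```

A resumed run would then process only the indices that are not in `completed_indices(resolve_output_dir(...))`.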
Hidden States Generation Arguments
- --concurrency (int, default: 32): Number of active vLLM requests at a time. The number of async workers is set to 2 * concurrency.
- --validate-outputs (flag): Load generated safetensor files and verify that output token IDs match prompt tokens and hidden states sequence length matches the number of tokens.
- --request-timeout (float): Timeout in seconds for each individual vLLM request.
- --max-retries (int): Maximum number of retry attempts per request on failure.
- --fail-on-error (flag): Abort when a request fails after all retries. By default, failed samples are skipped.
- --max-consecutive-errors (int, default: value of --concurrency): Abort after this many consecutive sample failures (each sample already retried --max-retries times). Prevents silently churning through the entire dataset when the server is down. Ignored when --fail-on-error is set.
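The way these options interact can be sketched with plain asyncio. The code below is a hedged illustration rather than the script's implementation: `process_all`, `send_request`, and `TooManyFailures` are hypothetical names, the defaults for `max_retries` and `request_timeout` are placeholders, and it assumes one initial attempt followed by up to `max_retries` retries.

```python
import asyncio
from typing import Awaitable, Callable, Optional


class TooManyFailures(RuntimeError):
    """Raised when too many samples fail in a row."""


async def process_all(
    samples: list,
    send_request: Callable[[object], Awaitable[object]],
    concurrency: int = 32,
    max_retries: int = 3,  # placeholder default, not taken from the script
    request_timeout: float = 60.0,  # placeholder default, not taken from the script
    fail_on_error: bool = False,
    max_consecutive_errors: Optional[int] = None,
) -> dict:
    if max_consecutive_errors is None:
        max_consecutive_errors = concurrency  # documented default
    queue: asyncio.Queue = asyncio.Queue()
    for item in enumerate(samples):
        queue.put_nowait(item)
    limiter = asyncio.Semaphore(concurrency)  # caps requests in flight
    results: dict = {}
    skipped: list = []
    consecutive = 0

    async def attempt(sample):
        # One initial attempt plus up to max_retries retries, each with a timeout.
        last_error: Optional[Exception] = None
        for _ in range(max_retries + 1):
            try:
                async with limiter:
                    return await asyncio.wait_for(send_request(sample), request_timeout)
            except Exception as exc:
                last_error = exc
        assert last_error is not None
        raise last_error

    async def worker():
        nonlocal consecutive
        while True:
            try:
                index, sample = queue.get_nowait()
            except asyncio.QueueEmpty:
                return
            try:
                results[index] = await attempt(sample)
                consecutive = 0  # any success resets the failure streak
            except Exception as exc:
                if fail_on_error:
                    raise  # abort on the first sample that exhausts its retries
                skipped.append(index)
                consecutive += 1
                if consecutive >= max_consecutive_errors:
                    raise TooManyFailures(f"{consecutive} samples failed in a row") from exc

    # Twice as many workers as allowed in-flight requests, as described above.
    workers = [asyncio.create_task(worker()) for _ in range(2 * concurrency)]
    try:
        await asyncio.gather(*workers)
    finally:
        for task in workers:
            task.cancel()
        await asyncio.gather(*workers, return_exceptions=True)
    return {"results": results, "skipped": skipped}
```

In this sketch the semaphore, not the worker count, bounds the number of active requests; the extra workers only keep the next samples ready while others wait on the server.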