Train MTP Model Online
This tutorial walks you through finetuning an MTP (Multi-Token Prediction) speculator model using online training, where hidden states are generated on-the-fly from a live vLLM server during training. This example uses Qwen/Qwen3.5-9B as the target model, but the process is the same for any model with native MTP support (e.g. other Qwen3.5 models).
Unlike Eagle-3, DFlash, or P-EAGLE, which train draft models from scratch, MTP finetuning starts from the model's native MTP head: you convert it to speculators format, finetune it on domain-specific data, and stitch the improved weights back into the verifier checkpoint.
For a ready-to-run version of this tutorial, see examples/train/mtp_qwen3_5_9b_gsm8k_online.sh.
Overview
Time required: Varies by model size. ~8 mins for Qwen3.5-9B on 2x H100 GPUs.
Prerequisites:
- Python 3.10+
- 2 CUDA-capable GPUs (one for vLLM server, one for training)
- A model with native MTP support (e.g. Qwen/Qwen3.5-9B, Qwen/Qwen3.5-0.8B)
- Training data regenerated by the target model (see Response Regeneration)
Step 0: Setup Your Environment
Create two separate virtual environments so that the speculators and vLLM dependencies don't conflict:
# Speculators venv (for data prep and training)
uv venv speculators_venv
source speculators_venv/bin/activate
uv pip install "speculators>=0.6.0"
# vLLM venv (for serving the target model)
uv venv vllm_venv
source vllm_venv/bin/activate
uv pip install "vllm>=0.22.0"
Note: if you are using an experiment tracker (e.g. trackio, wandb, tensorboard), install it in the speculators venv manually.
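To confirm each venv has what it needs before moving on, a quick check like the following can help. It is a minimal sketch that uses only the standard library (importlib.metadata) and does a simple numeric comparison of release segments, ignoring pre-release ordering; packaging.version would be the stricter choice. The helper names are our own, not part of speculators or vLLM.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def parse_release(v: str) -> tuple[int, ...]:
    """Leading numeric release segments, e.g. '0.22.0rc1' -> (0, 22, 0)."""
    parts: list[int] = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if ch not in "0123456789":
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) != len(piece):
            break  # stop at a segment with a suffix such as 'rc1' or '+cu121'
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """Numeric comparison only; pre-releases like '0.6.0rc1' count as '0.6.0'."""
    a, b = parse_release(installed), parse_release(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_env(requirements: dict[str, str]) -> dict[str, bool]:
    """Map each package name to whether its installed version meets the minimum."""
    results = {}
    for pkg, minimum in requirements.items():
        try:
            results[pkg] = meets_minimum(version(pkg), minimum)
        except PackageNotFoundError:
            results[pkg] = False  # not installed in the active venv
    return results


if __name__ == "__main__":
    print("Python >= 3.10:", sys.version_info >= (3, 10))
    # Run in each venv: expect speculators=True in one and vllm=True in the other.
    print(check_env({"speculators": "0.6.0", "vllm": "0.22.0"}))
```

Because the two packages live in different venvs, each run should report True only for the package belonging to the active environment.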
Step 1: Convert Native MTP Head
Extract the model's native MTP layers into speculators format. This step is unique to MTP; the other algorithms train their draft models from scratch.
# in speculators venv
python -c "
from speculators.convert import convert_model
convert_model(
    model='Qwen/Qwen3.5-9B',
    verifier='Qwen/Qwen3.5-9B',
    algorithm='mtp',
    output_path='./output/converted_mtp',
    num_speculative_steps=3,
)
"
Parameters explained:
- model - The source model containing native MTP layers (same as verifier for native MTP)
- verifier - The verifier model to attach
- algorithm - Must be "mtp" for MTP conversion
- output_path - Where to save the converted speculators checkpoint
- num_speculative_steps - Number of tokens to predict per step (default: 3)
Expected output:
output/converted_mtp/
├── config.json # MTPSpeculatorConfig
└── model.safetensors # Extracted MTP layer weights + embed_tokens + lm_head
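Before training, you may want to confirm that the conversion produced both files. The sketch below (the helper check_converted is hypothetical) checks for the two files shown above and loads config.json so you can inspect its fields. It deliberately does not assume any particular key names, since the MTPSpeculatorConfig schema isn't spelled out here.

```python
import json
from pathlib import Path

EXPECTED_FILES = ("config.json", "model.safetensors")


def check_converted(output_path: str) -> dict:
    """Raise if an expected file is missing; otherwise return the parsed config.json."""
    root = Path(output_path)
    missing = [name for name in EXPECTED_FILES if not (root / name).is_file()]
    if missing:
        raise FileNotFoundError(f"{root} is missing: {', '.join(missing)}")
    with open(root / "config.json", encoding="utf-8") as f:
        return json.load(f)


if __name__ == "__main__":
    cfg = check_converted("./output/converted_mtp")
    # Key names are not assumed here; list them to see what the converter wrote.
    print(sorted(cfg))
```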
Step 2: Prepare Your Data
MTP requires training data generated by the target model itself. You can download an already-regenerated dataset from HuggingFace, or create your own using the Response Regeneration pipeline.
# in speculators venv
# Download regenerated dataset
hf download \
inference-optimization/Qwen3.5-9B-responses gsm8k.jsonl \
--repo-type dataset \
--local-dir ./output/dataset
# Preprocess the regenerated data for training
python scripts/prepare_data.py \
--model Qwen/Qwen3.5-9B \
--data ./output/dataset/gsm8k.jsonl \
--output ./output/mtp_qwen3_5_9b \
--max-samples 5000 \
--seq-length 8192
Parameters explained:
- --model - The target model you want to accelerate
- --data - Path to regenerated dataset (jsonl file)
- --output - Where to save preprocessed data
- --max-samples - Limit samples (optional, good for testing/getting started)
- --seq-length - Maximum sequence length
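If you are unsure what value to pass to --max-samples, you can count the records in the downloaded file first. This sketch (the helper count_jsonl_records is hypothetical) only checks that every non-blank line parses as JSON; it makes no assumption about the record schema.

```python
import json


def count_jsonl_records(path: str) -> int:
    """Count non-blank lines in a JSONL file, failing on the first invalid line."""
    count = 0
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            try:
                json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(f"{path}:{lineno} is not valid JSON: {err}") from err
            count += 1
    return count


if __name__ == "__main__":
    n = count_jsonl_records("./output/dataset/gsm8k.jsonl")
    print(f"{n} records; --max-samples 5000 only truncates if n > 5000")
```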
Expected output:
output/mtp_qwen3_5_9b/
├── data-00000-of-00002.arrow # ⎤
├── data-00001-of-00002.arrow # | Processed dataset on disk
├── dataset_info.json # |
├── state.json # ⎦
└── token_freq.pt # Token frequencies
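The layout above (data-XXXXX-of-YYYYY.arrow shards plus dataset_info.json and state.json) matches what the Hugging Face datasets library writes with save_to_disk. A small completeness check like the one below can catch an interrupted run. It is a sketch: the number of shards depends on dataset size, and the helper name check_prepared is hypothetical.

```python
import re
from pathlib import Path

REQUIRED = ("dataset_info.json", "state.json", "token_freq.pt")
SHARD_RE = re.compile(r"data-(\d{5})-of-(\d{5})\.arrow")


def check_prepared(dataset_dir: str) -> list[str]:
    """Return names of expected files that are missing (an empty list means complete)."""
    root = Path(dataset_dir)
    missing = [name for name in REQUIRED if not (root / name).is_file()]
    shards: dict[int, set[int]] = {}  # total shard count -> shard indices present
    for p in root.glob("data-*-of-*.arrow"):
        m = SHARD_RE.fullmatch(p.name)
        if m:
            shards.setdefault(int(m.group(2)), set()).add(int(m.group(1)))
    if not shards:
        missing.append("data-*-of-*.arrow")
    for total in sorted(shards):
        for i in range(total):
            if i not in shards[total]:
                missing.append(f"data-{i:05d}-of-{total:05d}.arrow")
    return missing


if __name__ == "__main__":
    print(check_prepared("./output/mtp_qwen3_5_9b") or "all expected files present")
```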
Note: This step is the same for all speculator types. For more information, see the prepare_data.py CLI reference.
Step 3: Launch vLLM Server
Start vLLM to serve the verifier for hidden state extraction. The server stays running throughout training.
# in vLLM venv
CUDA_VISIBLE_DEVICES=0 python scripts/launch_vllm.py \
Qwen/Qwen3.5-9B \
--target-layer-ids 32 \
-- --port 8000
The -- separator: Anything after -- is passed directly to vLLM. Common options:
- --data-parallel-size 4 - Run 4 data-parallel replicas of the model
- --tensor-parallel-size 2 - Shard each replica across 2 GPUs with tensor parallelism
- --port 8000 - Port to serve on (default: 8000)
- --gpu-memory-utilization 0.9 - Fraction of each GPU's memory vLLM may use
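Parallelism options change how many GPUs the server needs: vLLM uses data-parallel-size × tensor-parallel-size GPUs in total, so --data-parallel-size 4 --tensor-parallel-size 2 needs 8 GPUs, all of which must be visible through CUDA_VISIBLE_DEVICES (the single-GPU example above exposes only GPU 0). The sketch below builds the launch command with a matching device list. The helper names are hypothetical, it only understands the two parallelism flags shown, and it assumes the training process will run on a GPU outside that list.

```python
import shlex


def flag_value(args: list[str], flag: str, default: int = 1) -> int:
    """Integer value of `flag` in args ("--flag N" or "--flag=N"), else `default`."""
    for i, arg in enumerate(args):
        if arg == flag and i + 1 < len(args):
            return int(args[i + 1])
        if arg.startswith(flag + "="):
            return int(arg.split("=", 1)[1])
    return default


def gpus_needed(vllm_args: list[str]) -> int:
    # Each data-parallel replica is sharded across tensor-parallel-size GPUs.
    dp = flag_value(vllm_args, "--data-parallel-size")
    tp = flag_value(vllm_args, "--tensor-parallel-size")
    return dp * tp


def build_launch_command(model: str, target_layer_ids: str, vllm_args: list[str]) -> str:
    # Expose GPUs 0..N-1 to the server; everything after "--" goes to vLLM unchanged.
    devices = ",".join(str(i) for i in range(gpus_needed(vllm_args)))
    cmd = ["python", "scripts/launch_vllm.py", model,
           "--target-layer-ids", target_layer_ids, "--", *vllm_args]
    return f"CUDA_VISIBLE_DEVICES={devices} {shlex.join(cmd)}"


if __name__ == "__main__":
    print(build_launch_command(
        "Qwen/Qwen3.5-9B", "32",
        ["--data-parallel-size", "4", "--tensor-parallel-size", "2", "--port", "8000"],
    ))
```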
Wait for the server to start. You should see:
INFO: Started server process [2140110]
INFO: Waiting for application startup.
INFO: Application startup complete
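If you start the server from a script rather than watching the terminal, you can poll it until it responds. This sketch assumes launch_vllm.py exposes vLLM's standard OpenAI-compatible API (the training step relies on it via http://localhost:8000/v1) and treats an HTTP 200 from GET /v1/models as ready. The probe and sleep functions are injectable so the polling logic can be tested offline; all helper names are hypothetical.

```python
import http.client
import time
import urllib.request
from typing import Callable


def http_ok(url: str) -> bool:
    """True if GET url answers with HTTP 200; False on any connection or HTTP error."""
    try:
        with urllib.request.urlopen(url, timeout=5) as resp:
            return resp.status == 200
    except (OSError, http.client.HTTPException):
        # URLError and HTTPError subclass OSError; covers "connection refused" while loading.
        return False


def wait_for_server(
    url: str = "http://localhost:8000/v1/models",
    timeout_s: float = 900.0,
    interval_s: float = 5.0,
    probe: Callable[[str], bool] = http_ok,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Poll `url` until it is ready or `timeout_s` elapses; return whether it became ready."""
    deadline = time.monotonic() + timeout_s
    while True:
        if probe(url):
            return True
        if time.monotonic() >= deadline:
            return False
        sleep(interval_s)


if __name__ == "__main__":
    print("server ready" if wait_for_server() else "timed out waiting for vLLM")
```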
Note: For more usage information, see the launch_vllm.py CLI reference.
Step 4: Train Against the Live vLLM Server
With the vLLM server running, train the MTP head. Hidden states are generated on-the-fly from the live server.
# in speculators venv (on a separate GPU)
CUDA_VISIBLE_DEVICES=1 python scripts/train.py \
--verifier-name-or-path Qwen/Qwen3.5-9B \
--data-path ./output/mtp_qwen3_5_9b \
--vllm-endpoint http://localhost:8000/v1 \
--save-path ./output/mtp_qwen3_5_9b/checkpoints \
--speculator-type mtp \
--from-pretrained ./output/converted_mtp \
--target-layer-ids 32 \
--step-weight-beta 0.6 \
--epochs 3 \
--lr 1e-4 \
--total-seq-len 8192 \
--on-missing generate \
--on-generate delete
Key MTP-specific parameters:
- --speculator-type mtp - Use the MTP algorithm
- --from-pretrained ./output/converted_mtp - Path to the converted MTP checkpoint from Step 1
- --step-weight-beta 0.6 - Exponential decay factor for per-step loss weights (default: 0.6)
- --on-missing generate - Generate hidden states on-the-fly from the vLLM server
- --on-generate delete - Delete generated hidden states after use (saves disk space)
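To make --step-weight-beta concrete: with an exponential decay factor β, a natural reading is that the loss for speculative step k (counting from 0) is weighted by β^k, so later, harder-to-predict steps contribute less. The sketch below computes those weights under that assumption. Whether train.py normalizes the weights or starts the exponent at 0 is not stated here, so check the train.py source before relying on exact values; the helper names are hypothetical.

```python
def step_weights(beta: float, num_steps: int, normalize: bool = False) -> list[float]:
    """Assumed per-step loss weights w_k = beta**k for k = 0..num_steps-1."""
    if not 0.0 < beta <= 1.0:
        raise ValueError("a decay factor should lie in (0, 1]")
    weights = [beta**k for k in range(num_steps)]
    if normalize and weights:
        total = sum(weights)
        weights = [w / total for w in weights]
    return weights


def combined_loss(step_losses: list[float], beta: float, normalize: bool = False) -> float:
    """Weighted sum of per-step losses (index 0 = first speculative token)."""
    weights = step_weights(beta, len(step_losses), normalize)
    return sum(w * loss for w, loss in zip(weights, step_losses))


if __name__ == "__main__":
    print(step_weights(0.6, 3))
    print(step_weights(0.6, 3, normalize=True))
```

With β = 0.6 and three steps, this gives weights of 1.0, 0.6, and 0.36 before any normalization.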
Note: MTP does not require --draft-vocab-size; it uses the full verifier vocabulary automatically. The number of speculative steps is read from the converted checkpoint's config, so --num-speculative-steps is also unnecessary when using --from-pretrained.
Note: Many more configuration options are available at this stage. We've tried to set sensible defaults; see the train.py CLI reference for the full list.
Step 5: Stop vLLM Server
After training is complete, stop the vLLM server (for example, with Ctrl+C in the terminal where it is running) to free its GPU.
Step 6: Stitch Finetuned Weights
After training, stitch the finetuned MTP weights back into the original verifier checkpoint. This produces a self-contained checkpoint deployable on vLLM with native MTP speculative decoding.
# in speculators venv
python scripts/stitch_mtp.py \
./output/mtp_qwen3_5_9b/checkpoints/checkpoint_best \
Qwen/Qwen3.5-9B \
--output-path ./output/stitched
Parameters explained:
- First argument: path to the finetuned MTP checkpoint
- Second argument: verifier model (HuggingFace ID or local path)
- --output-path - Where to save the stitched checkpoint (defaults to {verifier-name}-stitched)
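Conceptually, stitching is a key-by-key overwrite: every tensor in the verifier checkpoint is kept, except the MTP-head tensors, which are replaced by their finetuned counterparts. The sketch below shows that idea with plain dicts standing in for safetensors state dicts. It is not how stitch_mtp.py is implemented. The key names (mtp_layers.*, mtp.*) and the mapping between them are hypothetical, and the sketch assumes the shared embed_tokens and lm_head copied into the converted checkpoint were left untouched, so the verifier's originals are kept.

```python
from typing import Callable, Optional


def stitch(
    verifier: dict[str, object],
    finetuned_mtp: dict[str, object],
    key_map: Callable[[str], Optional[str]],
) -> dict[str, object]:
    """Return a copy of `verifier` with MTP tensors overwritten by finetuned ones."""
    stitched = dict(verifier)  # shallow copy: the original verifier dict is not modified
    for name, tensor in finetuned_mtp.items():
        target = key_map(name)
        if target is None:
            continue  # e.g. shared embed_tokens / lm_head: keep the verifier's copy
        if target not in stitched:
            raise KeyError(f"{name!r} maps to {target!r}, which the verifier does not have")
        stitched[target] = tensor
    return stitched


def hypothetical_key_map(name: str) -> Optional[str]:
    # Hypothetical naming: "mtp_layers.<rest>" in the speculators checkpoint
    # corresponds to "mtp.<rest>" in the verifier checkpoint.
    prefix = "mtp_layers."
    return "mtp." + name[len(prefix):] if name.startswith(prefix) else None
```

Raising on an unmapped target, rather than silently adding a new key, keeps a naming mismatch from producing a checkpoint that vLLM would load without the finetuned weights.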
Step 7: Test Your Model
Quick Test with vLLM
Serve the stitched checkpoint with MTP speculative decoding enabled:
# in vllm venv
vllm serve ./output/stitched \
--enable-auto-tool-choice \
--tool-call-parser hermes \
--speculative-config '{"method":"mtp","num_speculative_tokens":3}' \
--no-enable-chunked-prefill \
--port 8000
See vLLM Recipes for more deployment options and configurations.
Chat with the served model
While the model is served, in a separate window run:
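One way to do this from Python is to call vLLM's OpenAI-compatible /v1/chat/completions endpoint using only the standard library. This sketch assumes the default served model name, which vLLM takes from the path given to vllm serve (here ./output/stitched); query GET /v1/models if yours differs. The helper names and the sample prompt are illustrative.

```python
import json
import urllib.request

BASE_URL = "http://localhost:8000/v1"
MODEL = "./output/stitched"  # assumed: vLLM's default served name is the path passed to `vllm serve`


def build_chat_request(prompt: str, model: str = MODEL, max_tokens: int = 256) -> urllib.request.Request:
    """POST request for a single-turn chat completion."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    return urllib.request.Request(
        f"{BASE_URL}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def extract_reply(body: dict) -> str:
    """Pull the assistant text out of an OpenAI-style chat completion response."""
    return body["choices"][0]["message"]["content"]


def chat(prompt: str) -> str:
    with urllib.request.urlopen(build_chat_request(prompt), timeout=300) as resp:
        return extract_reply(json.load(resp))


if __name__ == "__main__":
    print(chat("A baker fills 7 trays with 12 muffins each. How many muffins is that?"))
```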
Verify Speculative Decoding
Check the vLLM server logs for speculative decoding metrics, such as the number of drafted and accepted tokens and the mean acceptance length.
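The two numbers most worth watching are the draft acceptance rate (accepted draft tokens ÷ drafted tokens) and the mean acceptance length (tokens produced per verifier step, counting the one token the verifier always emits, so 1 + accepted ÷ number of drafts). The sketch below computes both from raw counts. It assumes you have read the counts from the logs yourself, since the exact log format varies across vLLM versions; the counts in the usage example are made up for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SpecDecodeCounts:
    num_drafts: int           # verifier steps that checked a draft
    num_draft_tokens: int     # tokens proposed by the MTP head
    num_accepted_tokens: int  # proposed tokens the verifier accepted


def acceptance_rate(c: SpecDecodeCounts) -> float:
    """Fraction of drafted tokens that were accepted."""
    if c.num_draft_tokens <= 0:
        raise ValueError("no draft tokens recorded")
    return c.num_accepted_tokens / c.num_draft_tokens


def mean_acceptance_length(c: SpecDecodeCounts) -> float:
    """Tokens emitted per verifier step: accepted drafts plus the verifier's own token."""
    if c.num_drafts <= 0:
        raise ValueError("no drafts recorded")
    return 1.0 + c.num_accepted_tokens / c.num_drafts


if __name__ == "__main__":
    # Made-up counts for illustration, with num_speculative_tokens = 3.
    counts = SpecDecodeCounts(num_drafts=100, num_draft_tokens=300, num_accepted_tokens=180)
    print(f"acceptance rate: {acceptance_rate(counts):.2f}")
    print(f"mean acceptance length: {mean_acceptance_length(counts):.2f}")
```

With three speculative tokens, the mean acceptance length ranges from 1 (every draft rejected) to 4 (every draft accepted), so it is a direct read on how much the finetuned MTP head is helping.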
Next Steps
After training your model:
- Evaluate performance - See Evaluating Performance
- Deploy to production - See vLLM Recipes for deployment commands
- Fine-tune further - Use --from-pretrained ./output/mtp_qwen3_5_9b/checkpoints/checkpoint_best to continue training
- Upload to HuggingFace - Share your model with the community