Train MTP Model Online

This tutorial walks you through finetuning an MTP (Multi-Token Prediction) speculator model using online training: hidden states are generated on the fly from a live vLLM server while training runs. The example uses Qwen/Qwen3.5-9B as the target model, but the process is the same for any model with native MTP support (e.g. Qwen/Qwen3.5-0.8B).

Unlike Eagle-3, DFlash, or P-EAGLE, which train draft models from scratch, MTP finetuning starts from the model's native MTP head -- you convert it to speculators format, finetune it on domain-specific data, and stitch the improved weights back into the verifier checkpoint.

For a ready-to-run version of this tutorial, see examples/train/mtp_qwen3_5_9b_gsm8k_online.sh.

Overview

Time required: Varies by model size; about 8 minutes for Qwen3.5-9B on 2x H100 GPUs.

Prerequisites:

  • Python 3.10+
  • 2 CUDA-capable GPUs (one for vLLM server, one for training)
  • A model with native MTP support (e.g. Qwen/Qwen3.5-9B, Qwen/Qwen3.5-0.8B)
  • Training data regenerated by the target model (see Response Regeneration)

Step 0: Set Up Your Environment

Create two virtual environments. Keeping them separate is recommended so that the speculators and vLLM dependencies don't conflict:

# Speculators venv (for data prep and training)
uv venv speculators_venv
source speculators_venv/bin/activate
uv pip install "speculators>=0.6.0"
# vLLM venv (for serving the target model)
uv venv vllm_venv
source vllm_venv/bin/activate
uv pip install "vllm>=0.22.0"

Note: If you use an experiment tracker (e.g. trackio, wandb, tensorboard), install it manually in the speculators venv.
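To confirm that each venv picked up the minimum versions from the install commands above, a small check like the following can help. This is a minimal sketch using only the standard library; the helper names `parse_version`, `at_least`, and `meets_minimum` are made up for illustration, and the version parsing is deliberately simplistic (it ignores pre-release and post-release suffixes).

```python
from importlib.metadata import PackageNotFoundError, version


def parse_version(text: str) -> tuple:
    """Keep only the leading numeric parts, e.g. '0.22.0rc1' -> (0, 22, 0)."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def at_least(installed: str, minimum: str) -> bool:
    """Compare two version strings after padding them to the same length."""
    have, need = parse_version(installed), parse_version(minimum)
    width = max(len(have), len(need))
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have >= need


def meets_minimum(package: str, minimum: str) -> bool:
    """Return True if `package` is installed in the active venv at `minimum` or newer."""
    try:
        return at_least(version(package), minimum)
    except PackageNotFoundError:
        return False


if __name__ == "__main__":
    # Run once inside each venv; only the package installed there should pass.
    for name, minimum in [("speculators", "0.6.0"), ("vllm", "0.22.0")]:
        print(name, "ok" if meets_minimum(name, minimum) else "missing or too old")
```

Run it once in each venv: the speculators venv should report `speculators ok`, and the vLLM venv should report `vllm ok`.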

Step 1: Convert Native MTP Head

Extract the model's native MTP layers into speculators format. This is unique to MTP -- other algorithms train from scratch.

# in speculators venv
python -c "
from speculators.convert import convert_model
convert_model(
    model='Qwen/Qwen3.5-9B',
    verifier='Qwen/Qwen3.5-9B',
    algorithm='mtp',
    output_path='./output/converted_mtp',
    num_speculative_steps=3,
)
"

Parameters explained:

  • model - The source model containing native MTP layers (same as verifier for native MTP)
  • verifier - The verifier model to attach
  • algorithm - Must be "mtp" for MTP conversion
  • output_path - Where to save the converted speculators checkpoint
  • num_speculative_steps - Number of tokens to predict per step (default: 3)

Expected output:

output/converted_mtp/
├── config.json             # MTPSpeculatorConfig
└── model.safetensors       # Extracted MTP layer weights + embed_tokens + lm_head
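A quick way to confirm that the conversion produced the layout above is to check for the two files and look at the top-level keys of config.json. The sketch below uses only the standard library and makes no assumption about which fields the config contains; the function names are illustrative and not part of speculators.

```python
import json
from pathlib import Path

# File names taken from the expected output listing above.
EXPECTED_FILES = ("config.json", "model.safetensors")


def missing_files(checkpoint_dir: str) -> list:
    """Return the expected files that are absent from the converted checkpoint."""
    root = Path(checkpoint_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


def config_keys(checkpoint_dir: str) -> list:
    """List the top-level keys of config.json without assuming a schema."""
    with open(Path(checkpoint_dir) / "config.json", encoding="utf-8") as handle:
        return sorted(json.load(handle))


if __name__ == "__main__":
    target = "./output/converted_mtp"
    missing = missing_files(target)
    print("missing:", missing or "none")
    if "config.json" not in missing:
        print("config keys:", config_keys(target))
```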

Step 2: Prepare Your Data

MTP requires training data generated by the target model itself. You can use a pre-regenerated dataset from HuggingFace, or create your own using the Response Regeneration pipeline.

# in speculators venv
# Download regenerated dataset
hf download \
  inference-optimization/Qwen3.5-9B-responses gsm8k.jsonl \
  --repo-type dataset \
  --local-dir ./output/dataset

python scripts/prepare_data.py \
  --model Qwen/Qwen3.5-9B \
  --data ./output/dataset/gsm8k.jsonl \
  --output ./output/mtp_qwen3_5_9b \
  --max-samples 5000 \
  --seq-length 8192

Parameters explained:

  • --model - The target model you want to accelerate
  • --data - Path to regenerated dataset (jsonl file)
  • --output - Where to save preprocessed data
  • --max-samples - Limit samples (optional, good for testing/getting started)
  • --seq-length - Maximum sequence length

Expected output:

output/mtp_qwen3_5_9b/
├── data-00000-of-00002.arrow    #  ⎤
├── data-00001-of-00002.arrow    #  | Processed dataset on disk
├── dataset_info.json            #  |
├── state.json                   #  ⎦
└── token_freq.pt                # Token frequencies

Note: This step is the same for all speculator types. For more information, see the prepare_data.py CLI reference.
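If preprocessing fails or produces fewer samples than expected, it can be useful to look at the downloaded JSONL file first. The sketch below counts records and reports the keys of the first one. It assumes only that each non-empty line is a JSON value; the dataset's actual schema is not documented here, so the code does not rely on any particular field names.

```python
import json


def summarize_jsonl(path: str) -> tuple:
    """Return (number of records, sorted top-level keys of the first record)."""
    count = 0
    first_keys = []
    with open(path, encoding="utf-8") as handle:
        for line in handle:
            if not line.strip():
                continue  # tolerate blank lines
            record = json.loads(line)
            if count == 0 and isinstance(record, dict):
                first_keys = sorted(record)
            count += 1
    return count, first_keys
```

For example, `summarize_jsonl("./output/dataset/gsm8k.jsonl")` shows how many records are available to compare against `--max-samples`.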

Step 3: Launch vLLM Server

Start vLLM to serve the verifier for hidden state extraction. The server stays running throughout training.

# in vLLM venv
CUDA_VISIBLE_DEVICES=0 python scripts/launch_vllm.py \
  Qwen/Qwen3.5-9B \
  --target-layer-ids 32 \
  -- --port 8000

The -- separator: Anything after -- is passed directly to vLLM. Common options:

  • --data-parallel-size 4 - Use 4 data parallel instances
  • --tensor-parallel-size 2 - Group GPUs in pairs for tensor parallelism
  • --port 8000 - Specify port (default: 8000)
  • --gpu-memory-utilization 0.9 - Fraction of each GPU's memory that vLLM may use
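The pass-through convention can be illustrated in a few lines of Python. This is a sketch of the general idea (split the argument list at the first `--`), not the actual implementation in launch_vllm.py; `split_passthrough` is a hypothetical helper.

```python
def split_passthrough(argv: list) -> tuple:
    """Split arguments at the first '--' into (script args, pass-through args)."""
    if "--" not in argv:
        return list(argv), []
    index = argv.index("--")
    return argv[:index], argv[index + 1:]


if __name__ == "__main__":
    # Arguments from the launch command in this step.
    own, forwarded = split_passthrough(
        ["Qwen/Qwen3.5-9B", "--target-layer-ids", "32", "--", "--port", "8000"]
    )
    print("script args:", own)
    print("forwarded to vLLM:", forwarded)
```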

Wait for the server to start. Startup is complete when the log shows:

INFO:     Started server process [2140110]
INFO:     Waiting for application startup.
INFO:     Application startup complete

Note: For more information on usage, see the launch_vllm.py CLI reference.
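When scripting the workflow, watching the log by hand is inconvenient. One option is to poll the server over HTTP until it answers. The sketch below assumes that the server launched by launch_vllm.py exposes vLLM's usual OpenAI-compatible API, so that a GET on `/v1/models` returns 200 once startup is complete; `wait_for_server` is an illustrative helper, not part of speculators or vLLM.

```python
import time
import urllib.request


def wait_for_server(url: str, timeout: float = 600.0, interval: float = 5.0) -> bool:
    """Poll `url` until it returns HTTP 200 or `timeout` seconds elapse."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=5) as response:
                if response.status == 200:
                    return True
        except OSError:
            # Covers connection refused, timeouts, and HTTP errors
            # (urllib.error.URLError is a subclass of OSError).
            pass
        time.sleep(interval)
    return False
```

For example, `wait_for_server("http://localhost:8000/v1/models")` would block until the server from this step responds or ten minutes pass, and the return value tells the calling script whether it is safe to start training.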

Step 4: Train Against the Live vLLM Server

With the vLLM server running, train the MTP head. Hidden states are generated on-the-fly from the live server.

# in speculators venv (on a separate GPU)
CUDA_VISIBLE_DEVICES=1 python scripts/train.py \
  --verifier-name-or-path Qwen/Qwen3.5-9B \
  --data-path ./output/mtp_qwen3_5_9b \
  --vllm-endpoint http://localhost:8000/v1 \
  --save-path ./output/mtp_qwen3_5_9b/checkpoints \
  --speculator-type mtp \
  --from-pretrained ./output/converted_mtp \
  --target-layer-ids 32 \
  --step-weight-beta 0.6 \
  --epochs 3 \
  --lr 1e-4 \
  --total-seq-len 8192 \
  --on-missing generate \
  --on-generate delete

Key MTP-specific parameters:

  • --speculator-type mtp - Use the MTP algorithm
  • --from-pretrained ./output/converted_mtp - Path to the converted MTP checkpoint from Step 1
  • --step-weight-beta 0.6 - Exponential decay factor for per-step loss weights (default: 0.6)
  • --on-missing generate - Generate hidden states on-the-fly from the vLLM server
  • --on-generate delete - Delete generated hidden states after use (saves disk space)
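To give a feel for what an exponential decay factor of 0.6 means, the sketch below computes per-step weights as beta raised to the step index and combines per-step losses with them. This is an assumption about the general shape of the weighting: the exact formula used by train.py (starting index, any normalization) is not stated in this tutorial and may differ.

```python
def step_weights(beta: float, num_steps: int) -> list:
    """Weight for step k is beta**k, so later speculative steps count less."""
    return [beta ** k for k in range(num_steps)]


def weighted_loss(per_step_losses: list, beta: float) -> float:
    """Combine per-step losses into one scalar using the decaying weights."""
    weights = step_weights(beta, len(per_step_losses))
    return sum(w * loss for w, loss in zip(weights, per_step_losses))


if __name__ == "__main__":
    # With beta = 0.6 and 3 steps the weights are roughly 1.0, 0.6, 0.36.
    print(step_weights(0.6, 3))
```

Under this assumption, a smaller beta concentrates training on the first predicted token, while a beta close to 1 treats all speculative steps almost equally.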

Note: MTP does not require --draft-vocab-size -- it uses the full verifier vocabulary automatically. The number of speculative steps is read from the converted checkpoint's config, so --num-speculative-steps is also not needed when using --from-pretrained.

Note: Many more configuration options are available at this stage. The defaults are intended to be sensible; see the train.py CLI reference for the full list.

Step 5: Stop vLLM Server

After training is complete, stop the vLLM server:

# Press Ctrl+C in the vLLM terminal

Step 6: Stitch Finetuned Weights

After training, stitch the finetuned MTP weights back into the original verifier checkpoint. This produces a self-contained checkpoint deployable on vLLM with native MTP speculative decoding.

# in speculators venv
python scripts/stitch_mtp.py \
  ./output/mtp_qwen3_5_9b/checkpoints/checkpoint_best \
  Qwen/Qwen3.5-9B \
  --output-path ./output/stitched

Parameters explained:

  • First argument: path to the finetuned MTP checkpoint
  • Second argument: verifier model (HuggingFace ID or local path)
  • --output-path - Where to save the stitched checkpoint (defaults to {verifier-name}-stitched)
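Conceptually, stitching can be thought of as overwriting the verifier's MTP tensors with the finetuned ones while leaving every other tensor untouched. The sketch below shows that idea with plain dictionaries standing in for checkpoints. The `mtp.` key prefix and the `stitch` function are hypothetical; the key mapping and file handling in stitch_mtp.py are not described in this tutorial and may differ.

```python
def stitch(verifier: dict, finetuned: dict, prefix: str = "mtp.") -> dict:
    """Return a copy of `verifier` with tensors under `prefix` taken from `finetuned`.

    Keys outside `prefix` are ignored, so base-model tensors are never replaced
    in this sketch. The prefix is a placeholder, not the real naming scheme.
    """
    stitched = dict(verifier)  # shallow copy; the input is left unmodified
    for name, tensor in finetuned.items():
        if name.startswith(prefix):
            stitched[name] = tensor
    return stitched
```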

Step 7: Test Your Model

Quick Test with vLLM

Serve the stitched checkpoint with MTP speculative decoding enabled:

# in vLLM venv
vllm serve ./output/stitched \
  --enable-auto-tool-choice \
  --tool-call-parser hermes \
  --speculative-config '{"method":"mtp","num_speculative_tokens":3}' \
  --no-enable-chunked-prefill \
  --port 8000

See vLLM Recipes for more deployment options and configurations.

Chat with the served model

While the server is running, run the following in a separate terminal:

# in vLLM venv
vllm chat --url http://localhost:8000/v1
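The same server can also be queried programmatically through its OpenAI-compatible chat completions endpoint. The sketch below builds the request and extracts the reply using only the standard library. It assumes that the served model name matches the path passed to `vllm serve` (here `./output/stitched`); a GET on `/v1/models` lists the actual name if it differs. The helper names are illustrative.

```python
import json
import urllib.request


def build_chat_request(base_url: str, model: str, prompt: str,
                       max_tokens: int = 128) -> urllib.request.Request:
    """Build a POST request for the /chat/completions endpoint under `base_url`."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    return urllib.request.Request(
        f"{base_url.rstrip('/')}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def extract_reply(response_body: bytes) -> str:
    """Pull the assistant text out of a chat completions response body."""
    return json.loads(response_body)["choices"][0]["message"]["content"]


def chat(base_url: str, model: str, prompt: str) -> str:
    """Send one prompt to a running server and return the reply text."""
    request = build_chat_request(base_url, model, prompt)
    with urllib.request.urlopen(request, timeout=120) as response:
        return extract_reply(response.read())
```

With the server from this step running, `chat("http://localhost:8000/v1", "./output/stitched", "What is 12 * 7?")` returns the model's answer as a string.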

Verify Speculative Decoding

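Check the vLLM server logs for speculative decoding metrics, such as the draft acceptance rate, to confirm that draft tokens are being proposed and accepted.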
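Besides reading the logs, the metrics can be pulled from the server directly. The sketch below assumes that the server exposes Prometheus-format metrics at `/metrics` and that speculative decoding counters contain the substring `spec_decode` in their names; both are assumptions about the vLLM version in use, and the exact metric names can vary between releases.

```python
import urllib.request


def spec_decode_lines(metrics_text: str) -> list:
    """Keep metric sample lines that mention speculative decoding, skipping comments."""
    return [
        line
        for line in metrics_text.splitlines()
        if "spec_decode" in line and not line.startswith("#")
    ]


def fetch_metrics(url: str = "http://localhost:8000/metrics") -> str:
    """Download the raw metrics text from a running server."""
    with urllib.request.urlopen(url, timeout=10) as response:
        return response.read().decode("utf-8")
```

After sending a few requests, `spec_decode_lines(fetch_metrics())` should return a non-empty list if speculative decoding is active. An empty list suggests either that the feature is not enabled or that this vLLM version names its metrics differently.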

Next Steps

After training your model:

  1. Evaluate performance - See Evaluating Performance
  2. Deploy to production - See vLLM Recipes for deployment commands
  3. Finetune further - Use --from-pretrained ./output/mtp_qwen3_5_9b/checkpoints/checkpoint_best to continue training
  4. Upload to HuggingFace - Share your model with the community