多 Token 预测 (MTP)#
为什么我们需要 MTP#
MTP 通过并行预测多个 Token,将单 Token 生成模式转变为多 Token 生成,从而提升推理性能。这种方法在不损失输出质量的前提下,显著提高了生成吞吐量,并使推理速度实现成倍增长。
如何使用 MTP#
若要为 DeepSeek-V3 模型启用 MTP,请在启动服务时添加以下参数:
--speculative_config ' {"method": "mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False} '
num_speculative_tokens:投机 Token 的数量。如果提供该参数,模型将能够一次预测多个 Token。如果草稿模型(draft model)配置中已有该数值,则默认为配置值,否则必须手动指定。disable_padded_drafter_batch:禁用投机解码的输入填充。如果设置为 True,投机输入 Batch 可以包含不同长度的序列,但这可能仅受某些注意力后端(attention backend)支持。目前该参数仅影响 MTP 投机方法,默认为 False。
工作原理#
模块架构#
vllm_ascend
├── sample
│ ├── rejection_sample.py
├── spec_decode
│ ├── mtp_proposer.py
└───────────
1. 采样 (Sample)
rejection_sample.py:在解码过程中,主模型同时处理上一轮的输出 Token 和预测 Token(同时计算 1+k 个 Token)。第一个 Token 始终是正确的,而第二个 Token(称为 奖励 Token/Bonus Token)由于来自投机预测,具有不确定性。因此,我们采用贪婪策略 (Greedy Strategy) 和拒绝采样策略 (Rejection Sampling Strategy) 来确定是否接受该 Bonus Token。该模块包含一个
AscendRejectionSampler类,其 forward 方法实现了具体的采样逻辑。
rejection_sample.py
├── AscendRejectionSampler
│ ├── forward
2. 投机解码 (Spec_decode)
本节涵盖了投机解码的模型预处理,主要结构如下:包括加载模型、执行空运行(dummy run)以及生成 Token ID。这些步骤共同构成了单次投机解码操作的模型数据构建和前向调用。
mtp_proposer.py:配置 vLLM-Ascend 使用投机解码,其中候选 Token (Proposals) 由 DeepSeek MTP 层生成。
mtp_proposer.py
├── Proposer
│ ├── load_model
│ ├── dummy_run
│ ├── generate_token_ids
│ ├── _prepare_inputs
│ ├── _propose
算法#
1. 拒绝采样 (Reject_Sample)
贪婪策略
验证主模型生成的 Token 是否与上一轮 MTP 预测的投机 Token 匹配。如果完全匹配,则接受 Bonus Token;否则,拒绝该 Token 以及基于该次投机产生的所有后续 Token。
拒绝采样策略
该方法在拒绝采样中引入了随机性。
对于每个草稿 Token,通过验证不等式 P_target / P_draft ≥ U 是否成立来确定是否接受。其中 P_target 代表目标模型给当前草稿 Token 分配的概率,P_draft 表示草稿模型分配的概率,U 是从区间 [0, 1) 中均匀采样的随机数。
每个草稿 Token 的判定逻辑如下:如果不等式 P_target / P_draft ≥ U 成立,则该草稿 Token 被接受并输出;反之,如果 P_target / P_draft < U,则该草稿 Token 被拒绝。
当一个草稿 Token 被拒绝时,会触发恢复采样过程,从调整后的概率分布 Q = max(P_target - P_draft, 0) 中重新采样一个“恢复 Token”。在当前的 MTP 实现中,由于不提供 P_draft 且默认为 1,公式简化为:当 P_target ≥ U 时接受 Token,恢复分布变为 Q = max(P_target - 1, 0)。
2. 性能表现
如果 Bonus Token 被接受,MTP 模型将执行 (num_speculative + 1) 个 Token 的推理,包括主模型原始输出 Token 和 Bonus Token。如果被拒绝,则执行较少数量的 Token 推理,具体取决于有多少个 Token 被接受。
DFX (设计可靠性)#
方法验证#
目前,投机解码场景仅支持 ngram、eagle、eagle3 和 mtp 等方法。如果为 method 传递了错误的参数,代码将抛出错误,提示用户提供了不正确的方法。
def get_spec_decode_method(method,
vllm_config,
device,
runner):
if method == "ngram":
return NgramProposer(vllm_config, device, runner)
elif method in ["eagle", "eagle3"]:
return EagleProposer(vllm_config, device, runner)
elif method == 'mtp':
return MtpProposer(vllm_config, device, runner)
else:
raise ValueError("Unknown speculative decoding method: "
f"{method}")
整数验证#
当前的 npu_fused_infer_attention_score 算子在每轮解码中仅支持小于 16 的整数。因此,MTP 支持的最大值为 15。如果提供的值大于 15,代码将报错并提醒用户。
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
self.decode_threshold += spec_token_num
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
npu_fused_infer_attention_score TND layout's limit of 16, \
got {self.decode_threshold}"
局限性#
由于 DeepSeek 的 MTP 仅公开了单层权重,在 MTP > 1(尤其是 MTP ≥ 3)的场景下,准确性和性能无法得到有效保证。此外,受限于当前算子,MTP 最大支持 15。
在 MTP > 1 的全图模式 (fullgraph mode) 下,每个 aclgraph 的捕获大小必须是 (num_speculative_tokens + 1) 的整数倍。