多令牌预测 (MTP)#
为何需要 MTP#
MTP 通过并行预测多个令牌来提升推理性能,从单令牌生成转向多令牌生成。这种方法显著提高了生成吞吐量,并在不牺牲输出质量的前提下,实现了推理速度的倍增加速。
如何使用 MTP#
要为 DeepSeek-V3 模型启用 MTP,请在启动服务时添加以下参数:
--speculative_config ' {"method": "mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False} '
num_speculative_tokens:推测性令牌的数量,如果提供,则使模型能够一次预测多个令牌。如果草稿模型配置中存在此值,则默认使用该值,否则必须提供。disable_padded_drafter_batch:禁用推测解码的输入填充。如果设置为 True,推测输入批次可以包含不同长度的序列,这可能仅受某些注意力后端支持。目前这仅影响 MTP 推测方法,默认值为 False。
工作原理#
模块架构#
vllm_ascend
├── sample
│ ├── rejection_sample.py
├── spec_decode
│ ├── mtp_proposer.py
└───────────
1. 采样
rejection_sample.py:在解码过程中,主模型同时处理上一轮的输出令牌和预测的令牌(同时计算 1+k 个令牌)。第一个令牌总是正确的,而第二个令牌(称为奖励令牌)则不确定,因为它源自推测性预测,因此我们采用贪婪策略和拒绝采样策略来决定是否应接受该奖励令牌。该模块结构包含一个
AscendRejectionSampler类,其 forward 方法实现了具体的采样逻辑。
rejection_sample.py
├── AscendRejectionSampler
│ ├── forward
2. spec_decode
本节涵盖了 spec-decode 的模型预处理,主要结构如下:包括加载模型、执行虚拟运行以及生成令牌 ID。这些步骤共同构成了单次 spec-decode 操作的模型数据构建和前向调用。
mtp_proposer.py:配置 vLLM-Ascend 使用推测解码,其中提议由 DeepSeek MTP 层生成。
mtp_proposer.py
├── Proposer
│ ├── load_model
│ ├── dummy_run
│ ├── _prepare_inputs
│ ├── _propose
算法#
1. 拒绝采样
贪婪策略
验证主模型生成的令牌是否与上一轮 MTP 预测的推测令牌匹配。如果完全匹配,则接受奖励令牌;否则,拒绝该令牌以及源自该推测的任何后续令牌。
拒绝采样策略
此方法在拒绝采样中引入了随机性。
对于每个草稿令牌,通过验证不等式 P_target / P_draft ≥ U 是否成立来决定是否接受,其中 P_target 表示目标模型分配给当前草稿令牌的概率,P_draft 表示草稿模型分配的概率,U 是从区间 [0, 1) 均匀采样的随机数。
每个草稿令牌的决策逻辑如下:如果不等式 P_target / P_draft ≥ U 成立,则草稿令牌被接受作为输出;反之,如果 P_target / P_draft < U,则草稿令牌被拒绝。
当草稿令牌被拒绝时,会触发恢复采样过程,从调整后的概率分布 Q = max(P_target - P_draft, 0) 中重新采样一个“恢复令牌”。在当前 MTP 实现中,由于未提供 P_draft 且默认为 1,公式简化为:当 P_target ≥ U 时令牌被接受,恢复分布变为 Q = max(P_target - 1, 0)。
2. 性能
如果奖励令牌被接受,MTP 模型将对 (num_speculative + 1) 个令牌执行推理,包括原始主模型输出令牌和奖励令牌。如果被拒绝,则根据接受了多少个令牌来执行更少令牌的推理。
DFX#
方法验证#
目前,spec_decode 场景仅支持 n-gram、EAGLE、EAGLE3 和 MTP 等方法。如果为方法传递了错误的参数,代码将引发错误以提醒用户提供了不正确的方法。
def get_spec_decode_method(method,
vllm_config,
device,
runner):
if method == "ngram":
return AscendNgramProposer(vllm_config, device, runner)
elif method in ["eagle", "eagle3"]:
return AscendEagleProposer(vllm_config, device, runner)
elif method == 'mtp':
return AscendMtpProposer(vllm_config, device, runner)
else:
raise ValueError("Unknown speculative decoding method: "
f"{method}")
整数验证#
当前的 npu_fused_infer_attention_score 算子每轮解码仅支持小于 16 的整数。因此,MTP 支持的最大值为 15。如果提供了大于 15 的值,代码将引发错误并提醒用户。
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
self.decode_threshold += spec_token_num
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
npu_fused_infer_attention_score TND layout's limit of 16, \
got {self.decode_threshold}"
限制#
由于 DeepSeek 的 MTP 仅暴露了单层权重,因此在 MTP > 1(尤其是 MTP ≥ 3)的场景下,准确性和性能无法得到有效保证。此外,由于当前算子限制,MTP 最多支持 15。
在 MTP > 1 的 fullgraph 模式下,每个 ACLGraph 的捕获大小必须是 (num_speculative_tokens + 1) 的整数倍。