多 Token 预测 (MTP)#

为什么我们需要 MTP#

MTP 通过并行预测多个 Token,将单 Token 生成模式转变为多 Token 生成,从而提升推理性能。这种方法在不损失输出质量的前提下,显著提高了生成吞吐量,并使推理速度实现成倍增长。

如何使用 MTP#

若要为 DeepSeek-V3 模型启用 MTP,请在启动服务时添加以下参数:

--speculative_config ' {"method": "mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False} '

  • num_speculative_tokens:投机 Token 的数量。如果提供该参数,模型将能够一次预测多个 Token。如果草稿模型(draft model)配置中已有该数值,则默认为配置值,否则必须手动指定。

  • disable_padded_drafter_batch:禁用投机解码的输入填充。如果设置为 True,投机输入 Batch 可以包含不同长度的序列,但这可能仅受某些注意力后端(attention backend)支持。目前该参数仅影响 MTP 投机方法,默认为 False。

工作原理#

模块架构#

vllm_ascend
├── sample
│   ├── rejection_sample.py
├── spec_decode
│   ├── mtp_proposer.py
└───────────

1. 采样 (Sample)

  • rejection_sample.py:在解码过程中,主模型同时处理上一轮的输出 Token 和预测 Token(同时计算 1+k 个 Token)。第一个 Token 始终是正确的,而第二个 Token(称为 奖励 Token/Bonus Token)由于来自投机预测,具有不确定性。因此,我们采用贪婪策略 (Greedy Strategy)拒绝采样策略 (Rejection Sampling Strategy) 来确定是否接受该 Bonus Token。该模块包含一个 AscendRejectionSampler 类,其 forward 方法实现了具体的采样逻辑。

rejection_sample.py
├── AscendRejectionSampler
│   ├── forward

2. 投机解码 (Spec_decode)

本节涵盖了投机解码的模型预处理,主要结构如下:包括加载模型、执行空运行(dummy run)以及生成 Token ID。这些步骤共同构成了单次投机解码操作的模型数据构建和前向调用。

  • mtp_proposer.py:配置 vLLM-Ascend 使用投机解码,其中候选 Token (Proposals) 由 DeepSeek MTP 层生成。

mtp_proposer.py
├── Proposer
│   ├── load_model
│   ├── dummy_run
│   ├── generate_token_ids
│   ├── _prepare_inputs
│   ├── _propose

算法#

1. 拒绝采样 (Reject_Sample)

  • 贪婪策略

验证主模型生成的 Token 是否与上一轮 MTP 预测的投机 Token 匹配。如果完全匹配,则接受 Bonus Token;否则,拒绝该 Token 以及基于该次投机产生的所有后续 Token。

  • 拒绝采样策略

该方法在拒绝采样中引入了随机性。

对于每个草稿 Token,通过验证不等式 P_target / P_draft U 是否成立来确定是否接受。其中 P_target 代表目标模型给当前草稿 Token 分配的概率,P_draft 表示草稿模型分配的概率,U 是从区间 [0, 1) 中均匀采样的随机数。

每个草稿 Token 的判定逻辑如下:如果不等式 P_target / P_draft U 成立,则该草稿 Token 被接受并输出;反之,如果 P_target / P_draft < U,则该草稿 Token 被拒绝。

当一个草稿 Token 被拒绝时,会触发恢复采样过程,从调整后的概率分布 Q = max(P_target - P_draft, 0) 中重新采样一个“恢复 Token”。在当前的 MTP 实现中,由于不提供 P_draft 且默认为 1,公式简化为:当 P_target U 时接受 Token,恢复分布变为 Q = max(P_target - 1, 0)

2. 性能表现

如果 Bonus Token 被接受,MTP 模型将执行 (num_speculative + 1) 个 Token 的推理,包括主模型原始输出 Token 和 Bonus Token。如果被拒绝,则执行较少数量的 Token 推理,具体取决于有多少个 Token 被接受。

DFX (设计可靠性)#

方法验证#

  • 目前,投机解码场景仅支持 ngram、eagle、eagle3 和 mtp 等方法。如果为 method 传递了错误的参数,代码将抛出错误,提示用户提供了不正确的方法。

def get_spec_decode_method(method,
                           vllm_config,
                           device,
                           runner):
    if method == "ngram":
        return NgramProposer(vllm_config, device, runner)
    elif method in ["eagle", "eagle3"]:
        return EagleProposer(vllm_config, device, runner)
    elif method == 'mtp':
        return MtpProposer(vllm_config, device, runner)
    else:
        raise ValueError("Unknown speculative decoding method: "
                         f"{method}")

整数验证#

  • 当前的 npu_fused_infer_attention_score 算子在每轮解码中仅支持小于 16 的整数。因此,MTP 支持的最大值为 15。如果提供的值大于 15,代码将报错并提醒用户。

if self.speculative_config:
    spec_token_num = self.speculative_config.num_speculative_tokens
    self.decode_threshold += spec_token_num
    assert self.decode_threshold <= 16, f"decode_threshold exceeded \
        npu_fused_infer_attention_score TND layout's limit of 16, \
        got {self.decode_threshold}"

局限性#

  • 由于 DeepSeek 的 MTP 仅公开了单层权重,在 MTP > 1(尤其是 MTP ≥ 3)的场景下,准确性和性能无法得到有效保证。此外,受限于当前算子,MTP 最大支持 15。

  • 在 MTP > 1 的全图模式 (fullgraph mode) 下,每个 aclgraph 的捕获大小必须是 (num_speculative_tokens + 1) 的整数倍。