多Token预测(MTP)#

为什么需要 MTP#

MTP 通过并行化多个 Token 的预测,将单 Token 生成模式转变为多 Token 生成模式,从而提升推理性能。这种方法在不损失输出质量的前提下,显著提高了生成吞吐量,并实现了推理速度的倍增。

如何使用 MTP#

若要在 DeepSeek-V3 模型中启用 MTP,请在启动服务时添加以下参数:

--speculative_config ' {"method": "mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False} '

  • num_speculative_tokens: 猜测 Token 的数量。如果提供该参数,模型将能够一次性预测多个 Token。如果存在草图模型(draft model)配置,则默认为该配置中的数量,否则该参数是必需的。

  • disable_padded_drafter_batch: 禁用猜测式解码的输入填充。如果设置为 True,猜测输入 Batch 可以包含不同长度的序列,这通常仅受特定的 Attention 后端支持。目前该参数仅影响 MTP 猜测方法,默认值为 False。

工作原理#

模块架构#

vllm_ascend
├── sample
│   ├── rejection_sample.py
├── spec_decode
│   ├── mtp_proposer.py
└───────────

1. 采样 (sample)

  • rejection_sample.py: 在解码过程中,主模型会同时处理上一轮输出的 Token 和预测出的 Token(同时计算 1+k 个 Token)。第一个 Token 始终是正确的,而第二个 Token(被称为 Bonus Token)由于源自猜测预测,具有不确定性。因此,我们采用贪婪策略 (Greedy Strategy)拒绝采样策略 (Rejection Sampling Strategy) 来决定是否接受 Bonus Token。该模块结构包含一个 AscendRejectionSampler 类,通过其 forward 方法实现具体的采样逻辑。

rejection_sample.py
├── AscendRejectionSampler
│   ├── forward

2. 猜测式解码 (spec_decode)

本节包含猜测式解码的模型预处理,其主要结构如下:加载模型、执行空运行(dummy run)以及生成 Token ID。这些步骤共同构成了单次猜测式解码操作的模型数据构建和前向调用。

  • mtp_proposer.py: 配置 vLLM-Ascend 使用猜测式解码,其中的候选 Token(proposals)由 DeepSeek MTP 层生成。

mtp_proposer.py
├── Proposer
│   ├── load_model
│   ├── dummy_run
│   ├── _prepare_inputs
│   ├── _propose

算法#

1. 拒绝采样 (Rejection Sampling)

  • 贪婪策略 (Greedy Strategy)

验证主模型生成的 Token 是否与上一轮 MTP 预测的猜测 Token 匹配。如果完全匹配,则接受 Bonus Token;否则,拒绝该 Token 以及随后基于该猜测生成的所有 Token。

  • 拒绝采样策略 (Rejection Sampling Strategy)

此方法在拒绝采样中引入了随机性。

对于每个草图 Token(draft token),通过验证不等式 P_target / P_draft U 是否成立来确定是否接受。其中 P_target 表示目标模型分配给当前草图 Token 的概率,P_draft 表示草图模型分配的概率,U 是从区间 [0, 1) 中均匀采样的随机数。

每个草图 Token 的决策逻辑如下:如果不等式 P_target / P_draft U 成立,则接受该草图 Token 作为输出;反之,如果 P_target / P_draft < U,则拒绝该草图 Token。

当草图 Token 被拒绝时,会触发恢复采样过程,从调整后的概率分布 Q = max(P_target - P_draft, 0) 中重新采样一个“恢复 Token”。在当前的 MTP 实现中,由于不提供 P_draft 且默认值为 1,公式简化为:当 P_target U 时接受 Token,恢复分布变为 Q = max(P_target - 1, 0)

2. 性能 (Performance)

如果 Bonus Token 被接受,MTP 模型将对 (num_speculative + 1) 个 Token 进行推理,包括原始主模型输出的 Token 和 Bonus Token。如果被拒绝,则对较少的 Token 进行推理,具体取决于有多少 Token 被接受。

DFX (设计卓越性)#

方法验证#

  • 目前,spec_decode 场景仅支持 n-gram, EAGLE, EAGLE3 和 MTP 等方法。如果为 method 传递了错误的参数,代码将抛出错误,提醒用户提供了不正确的方法。

def get_spec_decode_method(method,
                           vllm_config,
                           device,
                           runner):
    if method == "ngram":
        return AscendNgramProposer(vllm_config, device, runner)
    elif method in ["eagle", "eagle3"]:
        return AscendEagleProposer(vllm_config, device, runner)
    elif method == 'mtp':
        return AscendMtpProposer(vllm_config, device, runner)
    else:
        raise ValueError("Unknown speculative decoding method: "
                         f"{method}")

整数验证#

  • 目前的 npu_fused_infer_attention_score 算子每轮解码仅支持小于 16 的整数。因此,MTP 支持的最大值为 15。如果提供的值大于 15,代码将抛出错误并提醒用户。

if self.speculative_config:
    spec_token_num = self.speculative_config.num_speculative_tokens
    self.decode_threshold += spec_token_num
    assert self.decode_threshold <= 16, f"decode_threshold exceeded \
        npu_fused_infer_attention_score TND layout's limit of 16, \
        got {self.decode_threshold}"

限制#

  • 由于 DeepSeek 的 MTP 仅公开了单层权重,在 MTP > 1(特别是 MTP ≥ 3)的场景下,准确性和性能无法得到有效保证。此外,受限于当前算子限制,MTP 最大支持 15。

  • 在 MTP > 1 的全图模式(fullgraph mode)下,每个 ACLGraph 的捕获大小(capture size)必须是 (num_speculative_tokens + 1) 的整数倍。