多令牌预测 (MTP)#

为何需要 MTP#

MTP 通过并行预测多个令牌来提升推理性能,从单令牌生成转向多令牌生成。这种方法显著提高了生成吞吐量,并在不牺牲输出质量的前提下,实现了推理速度的倍增加速。

如何使用 MTP#

要为 DeepSeek-V3 模型启用 MTP,请在启动服务时添加以下参数:

--speculative_config ' {"method": "mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False} '

  • num_speculative_tokens:推测性令牌的数量,如果提供,则使模型能够一次预测多个令牌。如果草稿模型配置中存在此值,则默认使用该值,否则必须提供。

  • disable_padded_drafter_batch:禁用推测解码的输入填充。如果设置为 True,推测输入批次可以包含不同长度的序列,这可能仅受某些注意力后端支持。目前这仅影响 MTP 推测方法,默认值为 False。

工作原理#

模块架构#

vllm_ascend
├── sample
│   ├── rejection_sample.py
├── spec_decode
│   ├── mtp_proposer.py
└───────────

1. 采样

  • rejection_sample.py:在解码过程中,主模型同时处理上一轮的输出令牌和预测的令牌(同时计算 1+k 个令牌)。第一个令牌总是正确的,而第二个令牌(称为奖励令牌)则不确定,因为它源自推测性预测,因此我们采用贪婪策略拒绝采样策略来决定是否应接受该奖励令牌。该模块结构包含一个 AscendRejectionSampler 类,其 forward 方法实现了具体的采样逻辑。

rejection_sample.py
├── AscendRejectionSampler
│   ├── forward

2. spec_decode

本节涵盖了 spec-decode 的模型预处理,主要结构如下:包括加载模型、执行虚拟运行以及生成令牌 ID。这些步骤共同构成了单次 spec-decode 操作的模型数据构建和前向调用。

  • mtp_proposer.py:配置 vLLM-Ascend 使用推测解码,其中提议由 DeepSeek MTP 层生成。

mtp_proposer.py
├── Proposer
│   ├── load_model
│   ├── dummy_run
│   ├── _prepare_inputs
│   ├── _propose

算法#

1. 拒绝采样

  • 贪婪策略

验证主模型生成的令牌是否与上一轮 MTP 预测的推测令牌匹配。如果完全匹配,则接受奖励令牌;否则,拒绝该令牌以及源自该推测的任何后续令牌。

  • 拒绝采样策略

此方法在拒绝采样中引入了随机性。

对于每个草稿令牌,通过验证不等式 P_target / P_draft U 是否成立来决定是否接受,其中 P_target 表示目标模型分配给当前草稿令牌的概率,P_draft 表示草稿模型分配的概率,U 是从区间 [0, 1) 均匀采样的随机数。

每个草稿令牌的决策逻辑如下:如果不等式 P_target / P_draft U 成立,则草稿令牌被接受作为输出;反之,如果 P_target / P_draft < U,则草稿令牌被拒绝。

当草稿令牌被拒绝时,会触发恢复采样过程,从调整后的概率分布 Q = max(P_target - P_draft, 0) 中重新采样一个“恢复令牌”。在当前 MTP 实现中,由于未提供 P_draft 且默认为 1,公式简化为:当 P_target U 时令牌被接受,恢复分布变为 Q = max(P_target - 1, 0)

2. 性能

如果奖励令牌被接受,MTP 模型将对 (num_speculative + 1) 个令牌执行推理,包括原始主模型输出令牌和奖励令牌。如果被拒绝,则根据接受了多少个令牌来执行更少令牌的推理。

DFX#

方法验证#

  • 目前,spec_decode 场景仅支持 n-gram、EAGLE、EAGLE3 和 MTP 等方法。如果为方法传递了错误的参数,代码将引发错误以提醒用户提供了不正确的方法。

def get_spec_decode_method(method,
                           vllm_config,
                           device,
                           runner):
    if method == "ngram":
        return AscendNgramProposer(vllm_config, device, runner)
    elif method in ["eagle", "eagle3"]:
        return AscendEagleProposer(vllm_config, device, runner)
    elif method == 'mtp':
        return AscendMtpProposer(vllm_config, device, runner)
    else:
        raise ValueError("Unknown speculative decoding method: "
                         f"{method}")

整数验证#

  • 当前的 npu_fused_infer_attention_score 算子每轮解码仅支持小于 16 的整数。因此,MTP 支持的最大值为 15。如果提供了大于 15 的值,代码将引发错误并提醒用户。

if self.speculative_config:
    spec_token_num = self.speculative_config.num_speculative_tokens
    self.decode_threshold += spec_token_num
    assert self.decode_threshold <= 16, f"decode_threshold exceeded \
        npu_fused_infer_attention_score TND layout's limit of 16, \
        got {self.decode_threshold}"

限制#

  • 由于 DeepSeek 的 MTP 仅暴露了单层权重,因此在 MTP > 1(尤其是 MTP ≥ 3)的场景下,准确性和性能无法得到有效保证。此外,由于当前算子限制,MTP 最多支持 15。

  • 在 MTP > 1 的 fullgraph 模式下,每个 ACLGraph 的捕获大小必须是 (num_speculative_tokens + 1) 的整数倍。