上下文并行 (CP)#

TL;DR PCP 通过序列分割加速预填充。DCP 消除 KV 缓存冗余。

ContextParallel

关于开发过程中的主要讨论,请参阅 RFC 以及该 RFC 引用或被引用的相关链接。

什么是 CP?#

上下文并行 (CP) 是一种沿序列维度在多个设备间并行计算的策略。

预填充上下文并行 (PCP) 扩展了设备的世界大小并使用专用的通信域。其主要目标是在预填充阶段对序列维度进行分区,使不同设备能同时计算序列的不同分块。KV 缓存沿序列维度跨设备分片。此方法在不同程度上影响了预填充和解码阶段的计算逻辑。

解码上下文并行 (DCP) 复用张量并行 (TP) 的通信域,且不需要额外的设备。其主要目标是通过在 TP 域内沿序列维度对 KV 缓存进行分片,消除原本会存储冗余副本的设备间的重复存储。DCP 主要影响解码逻辑,以及分块预填充和缓存预填充的逻辑。

如何使用 CP?#

详细信息请参阅 上下文并行用户指南

工作原理#

设备分布#

我们为 PCP 引入了新的通信域,并为 DCP 复用了 TP 的通信域,这是 PCP2、DCP2 和 TP4 的新设备布局。device_world

块表#

CP 对 KV 缓存存储执行序列分片。为了便于高效存储和访问,令牌以交错方式跨设备存储,交错粒度由 cp_kv_cache_interleave_size 决定,其默认值为 cp_kv_cache_interleave_size=1,也称为“令牌交错”。

鉴于 PCP 和 DCP 在 KV 缓存分片方面的行为相似,我们将它们统称为 CP。具体来说,cp_size = pcp_size * dcp_size,且 cp_rank = pcp_rank * dcp_size + dcp_rank

如图所示,块表中定义了一个虚拟块,同一 CP 设备组内的块构成一个虚拟块。虚拟块大小为 virtual_block_size = block_size * cp_size

对于任意令牌 x,参考下图,其(虚拟)块索引为 x // virtual_block_size,在虚拟块内的偏移量为 offset_within_virtual_block = x % virtual_block_size。本地块索引为 local_block_index = offset_within_virtual_block // cp_kv_cache_interleave_size,设备号为 target_rank = local_block_index % cp_size。在本地块内的偏移量为 (local_block_index // cp_size) * cp_kv_cache_interleave_size + offset_within_virtual_block % cp_kv_cache_interleave_size

BlockTable

基于上述逻辑,调整了 slot_mapping 的计算过程,并修改了每个设备上的 slot_mapping 值,以确保 KV 缓存沿序列维度分片并按预期存储在不同设备上。

当前实现要求 block_size % cp_kv_cache_interleave_size == 0

解码上下文并行 (DCP)#

如上所述,DCP 的主要功能是沿序列维度对 KV 缓存进行分片存储。其影响在于解码和分块预填充阶段的逻辑。

预填充阶段: 如图所示,在分块预填充计算期间,MLA 和 GQA 后端采用了两种不同的逻辑实现。

  • MLA 后端 中,执行上下文 KV 缓存 all_gather 操作以聚合完整的 KV 值。然后这些值与当前分块的 Q 值一起用于注意力计算。请注意,在多请求场景中,直接收集的 KV 结果在请求间是交错的。使用 reorg_kvcache 函数来重新组织 KV 缓存,确保同一请求的 KV 缓存被连续存储。

  • GQA 后端 中,沿头维度对 Q 执行 all_gather。这是因为 DCP 与 TP 通信域重叠,且 DCP 组内的 Q 头不同。然而,它们需要与本地计算的 KV 缓存交换结果以进行在线 Softmax 更新。为确保结果更新过程中的正确性,Q 值通过头维度的 all_gather 在 DCP 组内同步。在结果更新过程中,调用 cp_lse_ag_out_rs 来聚合 attn_outputattn_lse,更新结果,并对输出执行 reduce-scatter 操作。或者,我们可以使用 all-to-all 通信来交换输出和 LSE 结果,然后直接进行本地更新。这种方法与为 PCP 兼容性而调整的逻辑一致。

DCP-Prefill

解码阶段: 解码阶段的逻辑与 GQA 的分块预填充一致:首先沿 Q 头维度执行 all-gather 操作以确保 DCP 组内的一致性。使用本地 KV 缓存计算结果后,通过 cp_lse_ag_out_rs 函数更新结果。

DCP-Decode

预填充上下文并行 (PCP)#

头尾式令牌分区

PCP 需要在预填充阶段分割输入序列并确保跨设备的计算负载均衡。我们采用头尾式进行分割和连接:具体来说,首先将序列填充到长度为 2*pcp_size,然后分成 2*pcp_size 个相等的部分。第一部分与最后一部分合并,第二部分与倒数第二部分合并,依此类推,从而为每个设备分配计算上均衡的分块。此外,由于 KV 或 Q 的 allgather 聚合会导致来自不同请求的交错分块,我们计算 pcp_allgather_restore_idx 以快速恢复原始顺序。

这些逻辑在函数 _update_tokens_for_pcp 中实现。

PCP-Partition

预填充阶段:

在预填充阶段(不包括分块预填充),我们采用 all-gather KV 的方法来解决单个 GPU 上序列不完整的问题。需要注意的是,我们一次只聚合当前层的 KV 值,并且在使用后立即丢弃,以避免过高的峰值内存使用。此方法也可直接应用于 KV 缓存存储(由于 KV 缓存的分区方法与 PCP 序列分区不同,每个 GPU 都需要一份完整的 KV 值副本是不可避免的)。所有注意力后端在此逻辑上保持一致。

注意:虽然环形注意力方法也能以更低的峰值内存促进信息交换并实现计算-通信重叠,但在评估了开发复杂度高且重叠收益有限后,我们优先实现了 all-gather KV 方案。

PCP-Prefill

解码阶段:

在解码阶段,我们只需要在 DCP all-to-all 通信交换输出和 LSE 之后,于 PCP 组内添加一个 allgather,然后再进行输出更新。

PCP-Decode

分块预填充:

目前,有三种可行的分块预填充兼容性方法:AllGatherQAllGatherKVRing-Attn。由于 PCP 对查询序列和 KV 缓存都执行序列分片,我们需要确保其中一方拥有完整信息,或者采用类似 Ring-Attn 的方法顺序执行计算。Ring-Attn 的优缺点在此不赘述。

我们已在 GQA 注意力后端实现了 AllGatherQ 方法,并在 MLA 注意力后端实现了 AllGatherKV 方法。AllGatherQ 之后的工作流与解码阶段相同,而 AllGatherKV 之后的工作流与标准预填充阶段相同。详情请参考下图;具体步骤不再赘述。

一个重要注意事项:当上下文长度变得过长时,AllGatherKV 可能导致显著的峰值内存使用。为了缓解这个问题,我们采用了分段处理策略。通过预定义每轮处理的 KV 缓存最大量,我们依次完成每个分段的注意力计算和在线 softmax 更新。

PCP-ChunkedPrefill