Npugraph_ex#
工作原理#
这是一种基于 FX 图的优化,可视为 aclgraph 模式的一种加速方案。
您可以在 code 获取其代码
默认 FX 图优化#
FX 图处理过程#
对于模型的中间节点,将其包含的非原位运算符替换为原位运算符,以减少计算过程中的内存移动,提升性能。
对于模型的原始输入参数,如果包含原位运算符,Dynamo 的 Functionalize 过程会将其替换为非原位运算符 + 复制运算符的形式。npugraph_ex 将逆转此过程,恢复原位运算符,减少内存移动。
FX 融合处理过程#
npugraph_ex 目前提供三种默认的算子融合处理过程,未来将添加更多。
符合替换规则的算子组合可以被替换为相应的融合算子。
您可以查看默认的融合处理过程列表
自定义融合处理过程#
用户可以在 TorchAir 中注册自定义的图融合处理过程,以修改 PyTorch FX 图。注册依赖于 register_replacement API。
以下是该 API 的声明及其使用示例。
register_replacement(search_fn, replace_fn, example_inputs, trace_fn=fwd_only, extra_check=_return_true, search_fn_pattern=None)
参数名称 |
输入/输出 |
说明 |
是否必需 |
|---|---|---|---|
search_fn |
输入 |
此函数是您希望在 FX 图中识别的算子组合或计算逻辑,例如需要融合的算子组合 |
是 |
replace_fn |
输入 |
当在目标图中找到与 search_fn 对应的组合时,此函数的计算逻辑将替换原子图,以实现算子融合或优化。 |
是 |
example_inputs |
输入 |
用于追踪 search_fn 和 replace_fn 的示例输入张量。输入的形状和数据类型应与实际场景匹配。 |
是 |
trace_fn |
输入 |
默认情况下,仅追踪前向计算图,这适用于推理阶段的优化;如果需要支持训练场景,可以提供支持反向追踪的函数。 |
否 |
extra_check |
输入 |
算子融合后的额外验证函数。该函数的输入参数必须是来自 torch._inductor.pattern_matcher 的 Match 对象,用于对匹配结果进行进一步的自定义检查,例如检查融合后的算子是否在同一流上、检查设备类型、检查输入形状等。 |
否 |
search_fn_pattern |
输入 |
通常无需提供自定义模式对象。其定义遵循原生 PyTorch MultiOutputPattern 对象的规则。传入此参数后,将不再使用 search_fn 来匹配算子组合,而是直接使用此参数作为匹配规则。 |
否 |
使用示例
import functools
import torch, torch_npu, torchair
from torch._inductor.pattern_matcher import Match
from torch._subclasses.fake_tensor import FakeTensorMode
from torchair.core.utils import logger
# Assume fusing the add operator and the npu_rms_norm operator into the npu_add_rms_norm operator
# Define a search_fn to find the operator combinations in the original FX graph before fusion.
def search_fn(x1, x2, gamma):
xOut = torch.add(x1, x2)
y, _ = torch_npu.npu_rms_norm(xOut, gamma)
return y, xOut
# Define a replace_fn, that is, a fusion operator, used to replace operator combinations in the FX graph
def replace_fn(x1, x2, gamma):
y, _, xOut = torch_npu.npu_add_rms_norm(
x1, x2, gamma
)
return y, xOut
# extra_check can pass in additional validation logic. Here, it is used to check whether the last dimension of the first input parameter x1 is a specific value; if it is not the specific value, fusion is not allowed.
def extra_check(match: Match):
x1 = match.kwargs.get("x1")
if x1 is None:
return False
if not hasattr(x1, "meta") or "val" not in x1.meta:
return False
a_shape = x1.meta["val"].shape
return a_shape[-1] == 7168
# Define some sample inputs to trace search_fn and replace_fn into an FX graph
fake_mode = FakeTensorMode()
with fake_mode:
# sizes/values don't actually matter for initial trace
# once we get a possible match we re-trace with the actual values and verify the match still holds
input_tensor = functools.partial(torch.empty, (1, 1, 2), device="npu", dtype=torch.float16)
kwargs_tensor = functools.partial(torch.empty, 2, device="npu", dtype=torch.float16)
# Call the torchair.register_replacement API with search_fn, replace_fn, and example_inputs. If there are additional validations, you can pass them in as extra_check.
torchair.register_replacement(
search_fn=search_fn,
replace_fn=replace_fn,
example_inputs=(input_tensor(), input_tensor(), kwargs_tensor()),
extra_check=extra_check
)
npugraph_ex 中的默认融合处理过程也是基于此 API 实现的。您可以在 vllm-ascend 和 npugraph_ex 代码仓库中查看更多使用此 API 的示例。
DFX#
通过复用 PyTorch 社区的 TORCH_COMPILE_DEBUG 环境变量,当设置 TORCH_COMPILE_DEBUG=1 时,将输出整个过程中的 FX 图。