Npugraph_ex#

工作原理#

这是一种基于 FX 图的优化,可视为 aclgraph 模式的一种加速方案。

您可以在 code 获取其代码

默认 FX 图优化#

FX 图处理过程#

  • 对于模型的中间节点,将其包含的非原位运算符替换为原位运算符,以减少计算过程中的内存移动,提升性能。

  • 对于模型的原始输入参数,如果包含原位运算符,Dynamo 的 Functionalize 过程会将其替换为非原位运算符 + 复制运算符的形式。npugraph_ex 将逆转此过程,恢复原位运算符,减少内存移动。

FX 融合处理过程#

npugraph_ex 目前提供三种默认的算子融合处理过程,未来将添加更多。

符合替换规则的算子组合可以被替换为相应的融合算子。

您可以查看默认的融合处理过程列表

自定义融合处理过程#

用户可以在 TorchAir 中注册自定义的图融合处理过程,以修改 PyTorch FX 图。注册依赖于 register_replacement API。

以下是该 API 的声明及其使用示例。

register_replacement(search_fn, replace_fn, example_inputs, trace_fn=fwd_only, extra_check=_return_true, search_fn_pattern=None)

参数名称

输入/输出

说明

是否必需

search_fn

输入

此函数是您希望在 FX 图中识别的算子组合或计算逻辑,例如需要融合的算子组合

replace_fn

输入

当在目标图中找到与 search_fn 对应的组合时,此函数的计算逻辑将替换原子图,以实现算子融合或优化。

example_inputs

输入

用于追踪 search_fn 和 replace_fn 的示例输入张量。输入的形状和数据类型应与实际场景匹配。

trace_fn

输入

默认情况下,仅追踪前向计算图,这适用于推理阶段的优化;如果需要支持训练场景,可以提供支持反向追踪的函数。

extra_check

输入

算子融合后的额外验证函数。该函数的输入参数必须是来自 torch._inductor.pattern_matcher 的 Match 对象,用于对匹配结果进行进一步的自定义检查,例如检查融合后的算子是否在同一流上、检查设备类型、检查输入形状等。

search_fn_pattern

输入

通常无需提供自定义模式对象。其定义遵循原生 PyTorch MultiOutputPattern 对象的规则。传入此参数后,将不再使用 search_fn 来匹配算子组合,而是直接使用此参数作为匹配规则。

使用示例

import functools
import torch, torch_npu, torchair

from torch._inductor.pattern_matcher import Match
from torch._subclasses.fake_tensor import FakeTensorMode
from torchair.core.utils import logger

# Assume fusing the add operator and the npu_rms_norm operator into the npu_add_rms_norm operator
# Define a search_fn to find the operator combinations in the original FX graph before fusion.
def search_fn(x1, x2, gamma):
    xOut = torch.add(x1, x2)
    y, _ = torch_npu.npu_rms_norm(xOut, gamma)
    return y, xOut

# Define a replace_fn, that is, a fusion operator, used to replace operator combinations in the FX graph
def replace_fn(x1, x2, gamma):
    y, _, xOut = torch_npu.npu_add_rms_norm(
        x1, x2, gamma
    )
    return y, xOut

# extra_check can pass in additional validation logic. Here, it is used to check whether the last dimension of the first input parameter x1 is a specific value; if it is not the specific value, fusion is not allowed.
def extra_check(match: Match):
    x1 = match.kwargs.get("x1")

    if x1 is None:
        return False 
    if not hasattr(x1, "meta") or "val" not in x1.meta:
        return False

    a_shape = x1.meta["val"].shape
    return a_shape[-1] == 7168 


# Define some sample inputs to trace search_fn and replace_fn into an FX graph
fake_mode = FakeTensorMode()
with fake_mode:
    # sizes/values don't actually matter for initial trace
    # once we get a possible match we re-trace with the actual values and verify the match still holds
    input_tensor = functools.partial(torch.empty, (1, 1, 2), device="npu", dtype=torch.float16)
    kwargs_tensor = functools.partial(torch.empty, 2, device="npu", dtype=torch.float16)

    # Call the torchair.register_replacement API with search_fn, replace_fn, and example_inputs. If there are additional validations, you can pass them in as extra_check.
    torchair.register_replacement(
        search_fn=search_fn,
        replace_fn=replace_fn,
        example_inputs=(input_tensor(), input_tensor(), kwargs_tensor()),
        extra_check=extra_check
    )

npugraph_ex 中的默认融合处理过程也是基于此 API 实现的。您可以在 vllm-ascend 和 npugraph_ex 代码仓库中查看更多使用此 API 的示例。

DFX#

通过复用 PyTorch 社区的 TORCH_COMPILE_DEBUG 环境变量,当设置 TORCH_COMPILE_DEBUG=1 时,将输出整个过程中的 FX 图。