Npugraph_ex#

它是如何工作的?#

这是一种基于 FX 图的优化,可视为 aclgraph 模式的一种加速方案。

你可以获取其代码 code

默认 FX 图优化#

FX 图 pass#

  • 对于模型的中间节点,将节点中包含的非就地操作符替换为就地操作符,以减少计算过程中的内存移动并提升性能。

  • 对于模型的原始输入参数,如果它们包含就地操作符,Dynamo 的 Functionalize 过程会将就地操作符替换为非就地操作符 + 拷贝操作符的形式。npugraph_ex 会逆转此过程,恢复就地操作符并减少内存移动。

FX 融合 pass#

npugraph_ex 目前提供了一些算子融合 pass,未来将添加更多。

符合替换规则的算子组合可以被替换为对应的融合算子。

你可以获取默认的融合 pass 列表

自定义融合 pass#

用户可以在 npugraph_ex 中注册自定义的图融合 pass 来修改 PyTorch FX 图。注册依赖于 register_replacement API。

以下是该 API 的声明及其使用演示。

register_replacement(search_fn, replace_fn, example_inputs, trace_fn=fwd_only, extra_check=_return_true, search_fn_pattern=None)

参数名

输入/输出

说明

是否必需

search_fn

输入

该函数是你在 FX 图中想要识别的算子组合或计算逻辑,例如需要融合的算子组合

replace_fn

输入

当在目标图中找到与 search_fn 对应的组合时,该函数的计算逻辑将替换原始子图,以实现算子融合或优化。

example_inputs

输入

用于跟踪 search_fn 和 replace_fn 的示例输入张量。输入的形状和数据类型应与实际场景匹配。

trace_fn

输入

默认情况下,仅跟踪前向计算图,适用于推理阶段的优化;如果需要支持训练场景,可以提供支持反向跟踪的函数。

extra_check

输入

算子融合后的额外验证函数。该函数的输入参数必须是 torch._inductor.pattern_matcher 中的 Match 对象,用于对匹配结果进行进一步的自定义检查,例如检查融合算子是否在同一流上、检查设备类型、检查输入形状等。

search_fn_pattern

输入

通常无需提供自定义 pattern 对象。其定义遵循原生 PyTorch MultiOutputPattern 对象的规则。传入此参数后,将不再使用 search_fn 来匹配算子组合,而是直接使用此参数作为匹配规则。

使用示例

import functools
import torch, torch_npu, npugraph_ex

from torch._inductor.pattern_matcher import Match
from torch._subclasses.fake_tensor import FakeTensorMode
from npugraph_ex.core.utils import logger

# Assume fusing the add operator and the npu_rms_norm operator into the npu_add_rms_norm operator
# Define a search_fn to find the operator combinations in the original FX graph before fusion.
def search_fn(x1, x2, gamma):
    xOut = torch.add(x1, x2)
    y, _ = torch_npu.npu_rms_norm(xOut, gamma)
    return y, xOut

# Define a replace_fn, that is, a fusion operator, used to replace operator combinations in the FX graph
def replace_fn(x1, x2, gamma):
    y, _, xOut = torch_npu.npu_add_rms_norm(
        x1, x2, gamma
    )
    return y, xOut

# extra_check can pass in additional validation logic. Here, it is used to check whether the last dimension of the first input parameter x1 is a specific value; if it is not the specific value, fusion is not allowed.
def extra_check(match: Match):
    x1 = match.kwargs.get("x1")

    if x1 is None:
        return False 
    if not hasattr(x1, "meta") or "val" not in x1.meta:
        return False

    a_shape = x1.meta["val"].shape
    return a_shape[-1] == 7168 


# Define some sample inputs to trace search_fn and replace_fn into an FX graph
fake_mode = FakeTensorMode()
with fake_mode:
    # sizes/values don't actually matter for initial trace
    # once we get a possible match we re-trace with the actual values and verify the match still holds
    input_tensor = functools.partial(torch.empty, (1, 1, 2), device="npu", dtype=torch.float16)
    kwargs_tensor = functools.partial(torch.empty, 2, device="npu", dtype=torch.float16)

    # Call the npugraph_ex.register_replacement API with search_fn, replace_fn, and example_inputs. If there are additional validations, you can pass them in as extra_check.
    npugraph_ex.register_replacement(
        search_fn=search_fn,
        replace_fn=replace_fn,
        example_inputs=(input_tensor(), input_tensor(), kwargs_tensor()),
        extra_check=extra_check
    )

npugraph_ex 中的默认融合 pass 也是基于此 API 实现的。你可以在 vllm-ascend 和 npugraph_ex 代码仓库中查看更多使用此 API 的示例。

DFX#

通过复用 PyTorch 社区的 TORCH_COMPILE_DEBUG 环境变量,当设置 TORCH_COMPILE_DEBUG=1 时,将输出整个过程中的 FX 图。