Flash Attention 3#

备注

Ascend上的Flash Attention 3目前处于测试阶段。FA3所需的flash_attn_npu包已在GitHub上开源。请参考flash-attention-npu仓库获取更多详情。

本文档介绍如何在vLLM-Ascend中启用Flash Attention 3 (FA3)。FA3为Ascend NPU提供了训练推理一致的注意力实现。

动机#

在veRL等RL训练框架中,训练时的注意力计算使用Flash Attention。当vLLM-Ascend作为推理后端时,默认的Fused Infer Attention (FIA)实现与训练侧的Flash Attention不同,可能导致训练推理不一致。为解决此问题,vLLM-Ascend引入了FA3注意力后端以保持与训练侧的一致性。

FA3在以下场景中至关重要:

  • 训练推理一致性:确保推理时的注意力计算与训练侧一致,这对RL工作流(如veRL)至关重要,因为推理结果用于计算训练信号。

  • 框架调试:一致的注意力实现通过消除训练和推理之间的差异,使问题调试更加容易。

  • 强化学习 (RL):RL训练通常需要确定性和一致的rollout以实现可重复性和稳定训练。

功能对比#

下表比较了GPU FA3和Ascend NPU FA3中flash_attn_with_kvcache的功能:

功能

GPU FA3

NPU FA3

FP16 (float16)

BF16 (bfloat16)

因果注意力

滑动窗口注意力

-

MQA/GQA

分页KV缓存

旋转位置编码 (RoPE)

-

ALiBi

-

-

Softcapping

-

FP8量化

-

变长序列

与GPU实现的差异#

NPU上的flash_attn_with_kvcache接口在API参数上与GPU FA3版本语义一致。主要差异如下:

  1. NPU FA3不支持的功能:滑动窗口注意力、RoPE、ALiBi、Softcapping和FP8量化暂不支持。

  2. 图捕获flash_attn_with_kvcache的tiling在主机侧处理,目前正在优化中。不支持ACL图捕获(即无法捕获到计算图中进行加速)。启用FA3时请使用enforce_eager=True

硬件要求#

FA3目前需要Ascend Atlas A2和A3推理产品NPU。未来将支持其他NPU。

软件要求#

FA3需要flash_attn_npu包,该包提供包含flash_attn_with_kvcache算子的flash_attn_npu_v3模块。

安装#

安装flash_attn_npu wheel包,请参考:MinghuasLab/flash-attention-npu

启用Flash Attention 3#

要启用FA3,您需要:

  1. 设置环境变量export VLLM_BATCH_INVARIANT=1以启用batch invariant模式

  2. 通过LLM参数attention_backend="FLASH_ATTN"将注意力后端指定为FLASH_ATTN

在线推理(服务器模式)#

启动启用FA3的vLLM服务器:

VLLM_BATCH_INVARIANT=1 vllm serve Qwen/Qwen3-8B --attention-backend FLASH_ATTN

然后使用OpenAI兼容客户端:

from openai import OpenAI

client = OpenAI(
    api_key="EMPTY",
    base_url="http://localhost:8000/v1",
)

response = client.completions.create(
    model="Qwen/Qwen3-8B",
    prompt="The future of AI is",
    max_tokens=100,
    temperature=0.7,
    seed=42,
)

print(response.choices[0].text)

离线推理#

使用FA3进行离线批量推理:

import os
os.environ["VLLM_BATCH_INVARIANT"] = "1"

from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
    "Machine learning enables",
    "Deep learning models can",
]

sampling_params = SamplingParams(
    temperature=0.7,
    max_tokens=100,
    seed=42,
)

llm = LLM(
    model="Qwen/Qwen3-8B",
    tensor_parallel_size=1,
    attention_backend="FLASH_ATTN",
)

outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}")
    print(f"Generated: {generated_text!r}\n")

限制#

  • 包尚未开源:FA3所需的flash_attn_npu包尚未发布。在包可用之前,外部用户无法使用FA3。

  • 不支持滑动窗口:FA3不支持滑动窗口注意力。需要滑动窗口的模型需使用默认FIA后端。

  • 不支持ACL图捕获flash_attn_with_kvcache的tiling在主机侧处理,目前不支持ACL图捕获。启用FA3时请使用enforce_eager=True

  • 不支持RoPE:FA3不支持注意力核内的旋转位置编码。vLLM-Ascend通过使用PyTorch原生RoPE回退来修补此问题。

  • 不支持ALiBi:FA3不支持ALiBi(线性偏置注意力)。

  • 不支持Softcapping:FA3不支持注意力logit softcapping。

  • 不支持FP8量化:FA3不支持FP8量化注意力。

  • 不支持MLA和SFA:FA3不支持多头潜在注意力 (MLA) 或稀疏Flash Attention (SFA)。

备注

与默认FIA后端相比,启用FA3可能导致性能下降。这种权衡是有意为之,以保证训练推理一致性。

已测试模型#

FA3已在以下模型上测试验证:

  • Qwen3 (Dense)Qwen/Qwen3-0.6BQwen/Qwen3-1.7BQwen/Qwen3-8B

  • Qwen3 (MoE)Qwen/Qwen3-30B-A3B

其他模型尚未测试,未来将在测试后根据结果决定是否支持。

未来改进#

FA3功能正在积极开发中。计划改进包括:

  • 开源flash_attn_npu

  • 支持ACL图捕获(主机侧tiling优化)

  • 支持更多NPU系列

  • 扩展模型覆盖范围

  • 性能优化

  • 额外测试与验证