FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

下载 PDF

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness - 中文验证版

英文原始依据卡片:flashattention_2022.md

状态:已翻译。

元数据

  • Slug: flashattention_2022
  • 年份: 2022
  • 会议: NeurIPS
  • 作者: Tri Dao et al.
  • 阅读状态: read complete
  • 计算范式: 稀疏化与内存高效扩展
  • 主要来源: PDF抽取文本

计算设置

论文异常明确地描述了驱动该方法的设备层级:现代 GPU 具有小而快的 on-chip SRAM 层和更大但较慢的 high-bandwidth memory 层。A100 示例是核心:40-80GB HBM、1.5-2.0TB/s bandwidth,以及每个 streaming multiprocessor 192KB on-chip SRAM。论文的 Figure 1 还将对比表述为 SRAM bandwidth 约 19TB/s 对比 A100 HBM 约 1.5TB/s。因此,FlashAttention 不是作为抽象的 attention 变体提出的,而是一种围绕 GPU 计算速率、HBM bandwidth 和 SRAM 容量之间具体差距而设计的 exact attention 实现。

Benchmark 硬件也有列出:BERT-large 训练遵循 MLPerf 1.1 reference setup,使用 8xA100-80GB GPUs、FP16 training(通过 Apex AMP)。GPT-2 small 和 medium 训练使用 8xA100-40GB GPUs、effective batch size 512、gradient accumulation 以适配内存、mixed precision、AdamW 和 400K steps。Kernel 比较包括 A100-SXM4-40GB GPU、一块 40GB HBM 的 A100 GPU、RTX 3090 和 T4。论文明确指出加速比因 GPU 代际而异,因为 HBM bandwidth 和 SRAM size 会变化。

瓶颈

瓶颈是 IO,尤其是 HBM 流量,而不仅仅是 attention 的二次 FLOP 数量。标准 attention 将 score matrix S = QK^T 和 probability matrix P = softmax(S) 物化。这些 N x N 中间结果被写入 HBM,然后在后续阶段和 backward 计算中再次读回。论文给出标准 attention 的 HBM-access 复杂度为 Θ(Nd + N²),而 FlashAttention 将其降低到 Θ(N²d²M⁻¹),其中 M 是 SRAM size。这一表述在历史上很重要,因为许多近似或稀疏 attention 方法降低了 FLOPs,但不一定在实际加速器上减少了 wall-clock time。

这使得 sequence length 成为一个 memory-system 问题。在长 context 下,attention matrix 可能同时主导内存占用和 HBM bandwidth。在类似 BERT 的短序列上,存储 attention 可能是可容忍的;在 GPT 风格和长文档 workloads 中,写入和重新读取 N² 数据成为阻止 GPU 将其算术吞吐量转化为有用训练吞吐量的成本。

方法适配

FlashAttention 通过以下方式把 exact attention 适配到 GPU memory hierarchy:

  • 通过 on-chip SRAM 对 Q、K 和 V 做 tiling:K 和 V 的 blocks 从 HBM 移入 SRAM,Q 的 blocks 流式经过它们,输出累积而不在 HBM 中存储完整 attention matrix。
  • 在线计算 softmax:使用逐行运行统计量,使得 exact normalization 跨 tile 保持不变。
  • Forward pass 仅写入输出和紧凑的 normalization 元数据,不写入 S 和 P。
  • Backward pass 做 deliberate memory/compute tradeoff:从 Q、K、V recompute attention blocks,而非在 forward pass 中存储完整 attention matrix。这会增加 FLOPs,但去除了大量 HBM 写入和读取,实验表明这些是 wall-clock 的制约因素。
  • 将操作融合为一个 CUDA kernel:matrix multiply、masking、softmax、dropout 和第二个 matrix multiply 融合在一起,intermediates 保留在 registers/SRAM 中而非经 HBM 传输。
  • Block sizes 按硬件调优:较大的 head dimension 需要较小的 tiles 以适配 SRAM,T4 使用比 A100 更小的 blocks,因为其 SRAM 更小。
  • 将想法扩展到 block-sparse variant,保持相同的 IO-aware 逻辑。稀疏性仅在实现避免访问 HBM 中不必要的 blocks 时才能转化为 wall-clock 加速。

证据

训练证据是端到端的,不仅是 microbenchmarks。BERT-large 从相同的 MLPerf initialization 开始,目标为 72.0% masked-language-model validation accuracy,在 8xA100 GPUs 上训练时间为 17.4 +/- 1.4 分钟,对比引用的 NVIDIA MLPerf 1.1 结果为 20.0 +/- 1.5 分钟。GPT-2 small 在 OpenWebText 上使用 FlashAttention 训练到相同 reported perplexity 仅需 2.7 天,对比 HuggingFace 的 9.5 天和 Megatron-LM 的 4.7 天。GPT-2 medium 训练时间为 6.9 天,对比 HuggingFace 的 21.0 天和 Megatron-LM 的 11.5 天。

长 context 结果将 compute structure 与建模质量联系起来:使用 FlashAttention 的 GPT-2 small 在 4K context 下训练时间为 3.6 天(8xA100 GPUs),仍快于 1K-context Megatron-LM baseline 的 4.7 天,同时将 OpenWebText perplexity 从 18.2 提升到 17.5。在 Long Range Arena 上,论文报告总体 2.4x speedup,并展示 FlashAttention 使 Transformer 能够在 Path-X 和 Path-256 上取得结果,这些此前因长 exact attention 成本而无法实现。在 kernel profiling 中,标准 attention 读/写远更多的 HBM,论文报告在 GPT-2 上达到最高 7.6x attention-computation speedup。

历史影响

FlashAttention 将 memory hierarchy 确立为 Transformer scaling 的一等设计轴。此前的 attention 效率研究通常强调渐近 FLOPs 或稀疏模式。这篇论文论证了,在 GPU 上,决定性问题是算法是否在 HBM 中移动了超过必要的数据。它帮助将长 context Transformer 工程转向 IO-aware kernels、recomputation、fusion 以及与加速器绑定的 tile sizing。

局限

  • 加速依赖硬件和配置。
  • A100、RTX 3090 和 T4 表现出不同收益,因为 HBM bandwidth、SRAM size、head dimension、dropout、masking 和 sequence length 都起作用。
  • Backward pass 可能比某些 fused short-sequence kernel 更慢,因为 FlashAttention recompute 而非存储 attention matrix。
  • FMHA 在 BERT 短序列长度上仍相当,但不提供相同的长序列内存节省。
  • Block sizes 受 SRAM、registers 和 kernel implementation details 约束。

链接