FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Metadata

  • Slug: flashattention_2022
  • Year: 2022
  • Venue: NeurIPS
  • Authors: Tri Dao et al.
  • Reading status: read complete
  • Compute regime: Sparse and memory-efficient scaling
  • Primary sources: PDF, extracted text

Compute Setup

The paper is unusually explicit about the device hierarchy that motivates the method. It describes modern GPUs as having a small, fast on-chip SRAM tier and a much larger but slower high-bandwidth memory tier. The A100 example is central: 40-80GB of HBM with 1.5-2.0TB/s bandwidth, plus 192KB of on-chip SRAM per streaming multiprocessor. The paper's Figure 1 also frames the contrast as SRAM bandwidth on the order of 19TB/s versus A100 HBM at about 1.5TB/s. FlashAttention is therefore not presented as an abstract attention variant; it is an exact attention implementation shaped around the specific gap between GPU compute rate, HBM bandwidth, and SRAM capacity.

The benchmark hardware is also listed. BERT-large training follows the MLPerf 1.1 reference setup on 8xA100-80GB GPUs with FP16 training through Apex AMP. GPT-2 small and medium training uses 8xA100-40GB GPUs, an effective batch size of 512, gradient accumulation to fit memory, mixed precision, AdamW, and 400K steps. Kernel comparisons cover the A100 (described both as an A100-SXM4-40GB and as one A100 GPU with 40GB HBM), the RTX 3090, and the T4. The paper explicitly notes that speedup varies by GPU generation because HBM bandwidth and SRAM size differ between them.

Bottleneck

The bottleneck is IO, especially HBM traffic, not just the quadratic FLOP count of attention. Standard attention materializes the score matrix S = QK^T and the probability matrix P = softmax(S). Those N x N intermediates are written to HBM and then read back for later stages and for backward computation. The paper gives standard attention an HBM-access complexity of Theta(Nd + N^2), while FlashAttention reduces it to Theta(N^2 d^2 M^-1), where N is sequence length, d is head dimension, and M is SRAM size. Because d^2 is typically many times smaller than M for common head dimensions, the FlashAttention bound sits well below the N^2 term that dominates standard attention. That framing matters historically because many approximate or sparse attention methods reduced FLOPs but did not necessarily reduce wall-clock time on real accelerators.
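
To make the asymptotic comparison concrete, the two bounds can be evaluated for a few sequence lengths. This is an order-of-magnitude sketch only: Theta notation hides constant factors, the function names are hypothetical, and the SRAM budget (100KB holding 2-byte elements) and head dimension of 64 are assumptions chosen for illustration.

```python
def hbm_accesses_standard(n: int, d: int) -> int:
    """Leading-order HBM accesses for standard attention: N*d + N^2."""
    return n * d + n * n


def hbm_accesses_flash(n: int, d: int, m: int) -> float:
    """Leading-order HBM accesses for FlashAttention: N^2 * d^2 / M."""
    return n * n * d * d / m


# Assumed values (not from the paper's tables): head dimension 64,
# and 100KB of SRAM holding 2-byte elements.
HEAD_DIM = 64
SRAM_ELEMENTS = 100 * 1024 // 2

for seq_len in (1024, 4096, 16384):
    standard = hbm_accesses_standard(seq_len, HEAD_DIM)
    flash = hbm_accesses_flash(seq_len, HEAD_DIM, SRAM_ELEMENTS)
    print(f"N={seq_len:6d}  standard={standard:.3e}  "
          f"flash={flash:.3e}  ratio={standard / flash:.1f}")
```

For large N the ratio approaches M / d^2, so under these bounds the saving depends on how much SRAM is available relative to the squared head dimension rather than on sequence length.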

This makes sequence length a memory-system problem. At long context, the attention matrix can dominate both memory footprint and HBM bandwidth. On short BERT-like sequences, storing attention can be tolerable; on GPT-style and long-document workloads, writing and rereading N^2 data becomes the cost that prevents the GPU from turning its arithmetic throughput into useful training throughput.
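
The quadratic growth is easy to see with a little arithmetic. The sketch below computes the size of a single N x N intermediate per layer; the helper name and the example shape (batch 8, 12 heads, 2-byte elements) are assumptions for illustration, not figures from the paper.

```python
def attention_matrix_bytes(batch: int, heads: int, seq_len: int,
                           bytes_per_element: int = 2) -> int:
    """Bytes needed to hold one N x N matrix for every head of every sequence."""
    return batch * heads * seq_len * seq_len * bytes_per_element


# Assumed shape: batch 8, 12 heads, FP16 (2-byte) elements.
for seq_len in (512, 1024, 4096):
    gib = attention_matrix_bytes(8, 12, seq_len) / 2 ** 30
    print(f"N={seq_len:5d}: {gib:.2f} GiB for one N x N intermediate per layer")
```

Quadrupling the sequence length multiplies this footprint by 16, and standard attention both writes and rereads it.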

Method Adaptation

FlashAttention adapts exact attention to the GPU memory hierarchy by tiling Q, K, and V. Blocks of K and V are moved from HBM into SRAM, blocks of Q are streamed through them, and the output is accumulated without ever storing the full attention matrix in HBM. The softmax is computed online using per-row running statistics (a running maximum and a running sum of exponentials), so exact normalization is preserved across tiles. The forward pass writes only the output and these compact normalization statistics, not S and P.
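
The tiling and online softmax described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the paper's kernel: it uses nested lists instead of GPU tiles, processes one query row at a time rather than a block of Q, omits the usual score scaling (matching S = QK^T as written above), and defers the final normalization to the end of the loop. The function names are hypothetical.

```python
import math


def attention_reference(q, k, v):
    """Standard attention: builds the full score and probability row for each query."""
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        denom = sum(exps)
        out.append([sum(e * vj[c] for e, vj in zip(exps, v)) / denom
                    for c in range(len(v[0]))])
    return out


def attention_tiled(q, k, v, block_size):
    """Same result, visiting K/V one block at a time with running softmax statistics."""
    n_q, d_v = len(q), len(v[0])
    row_max = [-math.inf] * n_q                 # running maximum per query row
    row_sum = [0.0] * n_q                       # running sum of exponentials per row
    acc = [[0.0] * d_v for _ in range(n_q)]     # unnormalized output accumulator
    for start in range(0, len(k), block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        for i, qi in enumerate(q):
            scores = [sum(a * b for a, b in zip(qi, kj)) for kj in k_blk]
            new_max = max(row_max[i], max(scores))
            # Rescale what was accumulated under the old maximum; exp(-inf) is 0.0.
            scale = math.exp(row_max[i] - new_max)
            exps = [math.exp(s - new_max) for s in scores]
            row_sum[i] = row_sum[i] * scale + sum(exps)
            acc[i] = [a * scale + sum(e * vj[c] for e, vj in zip(exps, v_blk))
                      for c, a in enumerate(acc[i])]
            row_max[i] = new_max
    return [[a / row_sum[i] for a in acc[i]] for i in range(n_q)]


# Small demo: the tiled result should agree with the reference up to rounding.
demo_q = [[0.5, -1.0], [1.5, 0.25], [-0.75, 2.0]]
demo_k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, -0.7]]
demo_v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
ref = attention_reference(demo_q, demo_k, demo_v)
tiled = attention_tiled(demo_q, demo_k, demo_v, block_size=2)
print(max(abs(a - b) for ra, rb in zip(ref, tiled) for a, b in zip(ra, rb)))
```

Only the per-row maximum, the per-row sum, and the output accumulator persist across blocks, which is why no N x N matrix ever needs to exist at once.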

The backward pass makes a deliberate memory/compute tradeoff: it recomputes attention blocks from Q, K, and V rather than saving the full attention matrix during the forward pass. This adds FLOPs, but it removes large HBM writes and reads, which the experiments identify as the wall-clock limiter. The implementation fuses the matrix multiply, masking, softmax, dropout, and second matrix multiply into a single CUDA kernel so intermediates stay in registers/SRAM rather than bouncing through HBM. Block sizes are hardware-tuned: larger head dimensions need smaller tiles to fit SRAM, and T4 uses smaller blocks than A100 because its SRAM is smaller.
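
The recomputation idea can be illustrated without any gradient code. In the sketch below, a forward pass keeps only the output and a per-row (maximum, sum) pair, and a second function rebuilds any tile of P on demand from Q, K, and those saved statistics. The function names are hypothetical, the forward pass is left untiled for brevity, and the actual backward kernel does considerably more than this.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def forward_with_stats(q, k, v):
    """Attention forward pass that returns O and per-row (m, l), discarding P."""
    out, stats = [], []
    for qi in q:
        scores = [dot(qi, kj) for kj in k]
        m = max(scores)                               # row maximum
        l = sum(math.exp(s - m) for s in scores)      # row sum of exponentials
        probs = [math.exp(s - m) / l for s in scores]
        out.append([dot(probs, [vj[c] for vj in v]) for c in range(len(v[0]))])
        stats.append((m, l))
    return out, stats


def recompute_prob_tile(q_rows, k_blk, stats_rows):
    """Rebuild one tile of P from Q, K and the saved statistics."""
    return [[math.exp(dot(qi, kj) - m) / l for kj in k_blk]
            for qi, (m, l) in zip(q_rows, stats_rows)]


# Demo: two recomputed tiles together cover each full softmax row.
demo_q = [[0.2, -0.4], [1.0, 0.5]]
demo_k = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 2.0]]
demo_v = [[1.0], [2.0], [3.0], [4.0]]
_, demo_stats = forward_with_stats(demo_q, demo_k, demo_v)
left = recompute_prob_tile(demo_q, demo_k[:2], demo_stats)
right = recompute_prob_tile(demo_q, demo_k[2:], demo_stats)
print([sum(a) + sum(b) for a, b in zip(left, right)])  # each row sums to about 1
```

The saved state is two numbers per query row instead of N probabilities per row, at the cost of redoing the Q-K products for every tile that is rebuilt.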

The block-sparse extension keeps the same IO-aware logic. Sparsity only helps wall-clock time when the implementation avoids touching unnecessary blocks in HBM; otherwise theoretical sparsity can fail to translate into speed.
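
A sketch of the block-skipping logic, under the same simplifications as a plain-Python tiled loop: nested lists stand in for tiles, score scaling is omitted, and the names block_sparse_attention and block_mask are hypothetical. The mask layout is an assumption (block_mask[i][j] is True when query block i may attend to key block j), and every query block is assumed to have at least one allowed key block.

```python
import math


def block_sparse_attention(q, k, v, block, block_mask):
    """Tiled attention that skips every (query block, key block) pair masked out."""
    d_v = len(v[0])
    row_max = [-math.inf] * len(q)
    row_sum = [0.0] * len(q)
    acc = [[0.0] * d_v for _ in q]
    tiles_visited = 0
    for jb, k_start in enumerate(range(0, len(k), block)):
        k_blk = k[k_start:k_start + block]
        v_blk = v[k_start:k_start + block]
        for ib, q_start in enumerate(range(0, len(q), block)):
            if not block_mask[ib][jb]:
                continue  # skipped tile: no scores are computed for this pair
            tiles_visited += 1
            for i in range(q_start, min(q_start + block, len(q))):
                scores = [sum(a * b for a, b in zip(q[i], kj)) for kj in k_blk]
                new_max = max(row_max[i], max(scores))
                scale = math.exp(row_max[i] - new_max)
                exps = [math.exp(s - new_max) for s in scores]
                row_sum[i] = row_sum[i] * scale + sum(exps)
                acc[i] = [a * scale + sum(e * vj[c] for e, vj in zip(exps, v_blk))
                          for c, a in enumerate(acc[i])]
                row_max[i] = new_max
    out = [[a / row_sum[i] for a in acc[i]] for i in range(len(q))]
    return out, tiles_visited


# Demo: a block-diagonal mask visits 2 of the 4 tiles.
demo_q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
demo_k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
demo_v = [[1.0], [1.0], [5.0], [5.0]]
diagonal = [[True, False], [False, True]]
demo_out, demo_tiles = block_sparse_attention(demo_q, demo_k, demo_v, 2, diagonal)
print(demo_tiles, demo_out)
```

The work done scales with the number of tiles visited, which is the property the paragraph above depends on: a sparsity pattern only pays off if masked tiles are never touched.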

Evidence

The training evidence is end-to-end, not only microbenchmarks. BERT-large, starting from the same MLPerf initialization and targeting 72.0% masked-language-model validation accuracy, trains in 17.4 +/- 1.4 minutes on 8xA100 GPUs, compared with the cited NVIDIA MLPerf 1.1 result of 20.0 +/- 1.5 minutes (about 15% faster). GPT-2 small on OpenWebText trains to the same reported perplexity in 2.7 days with FlashAttention, versus 9.5 days for HuggingFace and 4.7 days for Megatron-LM (about 3.5x and 1.7x faster). GPT-2 medium trains in 6.9 days versus 21.0 days for HuggingFace and 11.5 days for Megatron-LM (about 3.0x and 1.7x faster).

The long-context result connects compute structure to modeling quality. GPT-2 small with FlashAttention at 4K context trains in 3.6 days on 8xA100 GPUs, still faster than the 1K-context Megatron-LM baseline at 4.7 days, while improving OpenWebText perplexity from 18.2 to 17.5. On Long Range Arena, the paper reports a 2.4x speedup overall and shows that FlashAttention enables Transformer results on Path-X (sequence length 16K) and Path-256 (sequence length 64K), tasks that were previously out of reach because exact attention was too costly at those lengths. In kernel profiling, standard attention reads and writes far more HBM data, and the paper reports up to 7.6x attention-computation speedup on GPT-2.

Historical Effect

FlashAttention made memory hierarchy a first-class design axis for Transformer scaling. Earlier attention-efficiency work often emphasized asymptotic FLOPs or sparsity patterns. This paper argued that, on GPUs, the decisive question is whether the algorithm moves data through HBM more than necessary. It helped shift long-context Transformer engineering toward IO-aware kernels, recomputation, fusion, and tile sizing tied to the accelerator.

Limits

The speedup is hardware- and configuration-dependent. A100, RTX 3090, and T4 show different gains because HBM bandwidth, SRAM size, head dimension, dropout, masking, and sequence length all matter. The backward pass can be slower than some fused short-sequence kernels because FlashAttention recomputes instead of storing the attention matrix. NVIDIA's Apex FMHA kernel remains competitive at short BERT sequence lengths, though it does not provide the same long-sequence memory savings. The paper supports exact attention up to much longer sequences, but block size, registers, SRAM, and kernel implementation still bound what fits efficiently.

Links