FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

Metadata

  • Slug: flashattention2_2023
  • Year: 2023
  • Venue: arXiv
  • Authors: Tri Dao
  • Reading status: read complete
  • Compute regime: Sparse and memory-efficient scaling
  • Primary sources: PDF, extracted text

Compute Setup

The paper explicitly benchmarks NVIDIA A100 and H100 GPUs. The hardware background section uses the A100 as its main example: 40-80GB of HBM with 1.5-2.0TB/s bandwidth, and 192KB of on-chip SRAM on each of its 108 streaming multiprocessors (SMs). It estimates shared-memory bandwidth at around 19TB/s and focuses on HBM and SRAM because the L2 cache is not directly programmer-controlled. It also contrasts the A100's theoretical peak of 312 TFLOPs/s for FP16/BF16 matmul with 19.5 TFLOPs/s for non-matmul FP32, a 16x gap.

The attention benchmarks are measured on an A100 80GB SXM4 GPU across sequence lengths 512 through 16k, with batch size chosen so total tokens are 16k, hidden dimension 2048, and head dimension 64 or 128. H100 results use an H100 80GB SXM5, but the implementation does not use H100-specific features such as TMA, fourth-generation Tensor Cores, or FP8. End-to-end GPT-style training is measured on 8xA100 80GB SXM GPUs for 1.3B and 2.7B parameter models at 2k and 8k context lengths.
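
The microbenchmark grid described above can be enumerated with a few lines of arithmetic. This is only a sketch: it assumes "16k" means 16384 tokens and that sequence lengths step in powers of two from 512 to 16k, and the function name is illustrative rather than taken from the paper's code.

```python
# Sketch of the attention benchmark grid: total tokens fixed, hidden dim 2048.
TOTAL_TOKENS = 16 * 1024   # assumption: "16k" read as 16384
HIDDEN_DIM = 2048


def benchmark_configs():
    """Return (seq_len, batch_size, head_dim, num_heads) for each setting."""
    configs = []
    seq_len = 512
    while seq_len <= TOTAL_TOKENS:          # assumption: powers of two
        batch_size = TOTAL_TOKENS // seq_len
        for head_dim in (64, 128):
            num_heads = HIDDEN_DIM // head_dim
            configs.append((seq_len, batch_size, head_dim, num_heads))
        seq_len *= 2
    return configs


configs = benchmark_configs()
for cfg in configs:
    print(cfg)
```

Under these assumptions, the longest sequences run at batch size 1 with 16 or 32 heads, which is the low-parallelism regime the paper's occupancy argument targets.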

Bottleneck

Standard attention is bottlenecked by quadratic memory traffic. For sequence length N, it materializes the N x N score and probability matrices in HBM, requiring O(N^2) memory and repeated HBM reads/writes. In typical settings, where sequence length is 1k-8k and head dimension is 64-128, this memory traffic dominates wall-clock time.
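
To give a sense of scale, the following sketch compares the size of one materialized score matrix with the size of the Q, K, V, and output tensors for a single head. It assumes 2-byte (FP16/BF16) elements; the helper names are illustrative.

```python
def score_matrix_bytes(seq_len, bytes_per_element=2):
    """Bytes needed to store one N x N score matrix for a single head."""
    return seq_len * seq_len * bytes_per_element


def qkvo_bytes(seq_len, head_dim, bytes_per_element=2):
    """Bytes for Q, K, V and the output O of one head (each N x d)."""
    return 4 * seq_len * head_dim * bytes_per_element


for n in (1024, 2048, 4096, 8192):
    scores = score_matrix_bytes(n)
    io = qkvo_bytes(n, head_dim=64)
    print(f"N={n}: scores {scores / 2**20:.0f} MiB, "
          f"Q/K/V/O {io / 2**20:.2f} MiB, ratio {scores / io:.0f}x")
```

The ratio grows linearly with N (it is N / (4d) under these assumptions), which is why the intermediate matrices, not the inputs and outputs, dominate HBM traffic at longer sequence lengths.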

FlashAttention-1 fixed much of the HBM problem by tiling attention through SRAM and avoiding materializing the full score/probability matrices, but it still did not behave like a highly optimized GEMM. The paper reports that FlashAttention forward reaches only 30-50% of theoretical max FLOPs/s, and backward only 25-35% on A100, whereas optimized GEMM can reach 80-90%. The remaining bottleneck is compute-structure mismatch: too many non-matmul operations for Tensor Core hardware, low occupancy when batch size or head count is small, and unnecessary shared-memory communication between warps.

Method Adaptation

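FlashAttention-2 keeps exact attention and the IO-aware tiling/recomputation idea, but changes how the work maps to the GPU. First, it reduces non-matmul FLOPs, for example by leaving the output accumulator unnormalized across blocks and dividing by the softmax row sum only once at the end, and by saving only the row-wise logsumexp for the backward pass rather than both the max and the sum. That is hardware-critical because, at 312 versus 19.5 TFLOPs/s, each non-matmul FLOP on A100 costs about 16x as much as a Tensor Core matmul FLOP. The goal is to spend more time in the fast matmul units and less in non-matmul FP32 work such as exponentials and rescaling.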
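
The tiling idea with a single, deferred normalization can be sketched in plain Python. This is a scalar illustration under stated assumptions, not the CUDA kernel: it omits the softmax scaling factor and masking, uses nested lists instead of tensors, and all function names are hypothetical.

```python
import math


def reference_attention(q, k, v):
    """softmax(q k^T) v, materializing every score row in full."""
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        p = [math.exp(s - m) for s in scores]
        denom = sum(p)
        out.append([sum(pj * vj[c] for pj, vj in zip(p, v)) / denom
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, block=2):
    """Visit K/V in blocks, keeping a running max, a running sum, and an
    unnormalized output; divide by the sum only once at the end."""
    d_v = len(v[0])
    out, lse = [], []
    for qi in q:
        m = -math.inf        # running row max
        l = 0.0              # running sum of exp(score - m)
        acc = [0.0] * d_v    # unnormalized output accumulator
        for start in range(0, len(k), block):
            kb, vb = k[start:start + block], v[start:start + block]
            scores = [sum(a * b for a, b in zip(qi, kj)) for kj in kb]
            m_new = max(m, max(scores))
            scale = math.exp(m - m_new)   # rescale old state to the new max
            p = [math.exp(s - m_new) for s in scores]
            l = l * scale + sum(p)
            acc = [a * scale + sum(pj * vj[c] for pj, vj in zip(p, vb))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])  # single final normalization
        lse.append(m + math.log(l))       # logsumexp, kept for a backward pass
    return out, lse


q = [[0.1, 0.2], [0.3, -0.1]]
k = [[0.5, 0.1], [-0.2, 0.4], [0.3, 0.3], [0.0, -0.5], [0.7, 0.2]]
v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [1.0, 1.0], [-1.0, 2.0]]
out, lse = tiled_attention(q, k, v, block=2)
ref = reference_attention(q, k, v)
```

The tiled result should match the reference to floating-point tolerance, since the rescaling by `exp(m - m_new)` only changes the reference point of the exponentials, not the final ratio.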

Second, it parallelizes over sequence length in addition to batch and heads. FlashAttention-1 assigns one thread block to an attention head and therefore gets only batch size times number of heads blocks. For long sequences, batch size is often small; adding sequence-dimension parallelism increases occupancy and lets more SMs do useful work.
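
A rough count of thread blocks shows why the extra dimension matters. The sketch below assumes each thread block handles 128 query rows, which is an illustrative block size rather than a figure from these notes, and compares the block count against the A100's 108 SMs.

```python
import math

NUM_SMS = 108  # A100 streaming multiprocessors, per the hardware section


def blocks_batch_heads(batch_size, num_heads):
    """Thread blocks when one block handles one (batch, head) pair."""
    return batch_size * num_heads


def blocks_with_seq_parallelism(batch_size, num_heads, seq_len, block_rows=128):
    """Thread blocks when the query sequence is also split into row blocks.

    block_rows=128 is an assumed tile height, chosen only for illustration.
    """
    return batch_size * num_heads * math.ceil(seq_len / block_rows)


# Longest benchmark setting with head dimension 128: batch 1, 16 heads, 16k tokens.
before = blocks_batch_heads(1, 16)
after = blocks_with_seq_parallelism(1, 16, 16384)
print(f"batch x heads only: {before} blocks for {NUM_SMS} SMs")
print(f"with sequence split: {after} blocks for {NUM_SMS} SMs")
```

With only 16 blocks, most of the 108 SMs would have nothing to run; splitting along the sequence yields far more blocks than SMs under this assumed tile size.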

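Third, it changes warp partitioning inside a thread block to reduce shared-memory reads/writes and synchronization. Where FlashAttention splits K and V across the warps of a block, so that each warp must write partial results to shared memory and synchronize before they are summed, FlashAttention-2 splits Q across warps and keeps K and V accessible to all of them, so each warp produces its own slice of the output without inter-warp communication. The kernel-level design is therefore not just "less memory" but "better work placement": which thread blocks own which sequence chunks, how warps communicate, and what block sizes fit registers and shared memory without spilling. The paper notes that larger blocks reduce shared-memory traffic but can exceed register/shared-memory budgets, so tuning depends on head dimension and device shared memory.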

Evidence

The paper reports attention microbenchmarks where FlashAttention-2 is 1.7-3.0x faster than FlashAttention, 1.3-2.5x faster than FlashAttention in Triton, and 3-10x faster than standard PyTorch attention. On A100, it reaches up to 230 TFLOPs/s in attention forward and 73% of theoretical maximum. It reaches up to 63% of theoretical maximum in backward. These numbers matter because the claim is not only asymptotic memory reduction; it is moving attention closer to GEMM utilization on real accelerator kernels.
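
The utilization figures can be sanity-checked against the A100 matmul peak quoted in the hardware section. This is a back-of-the-envelope sketch; the implied backward throughput is derived here from the reported percentage and is not a number taken from the paper.

```python
A100_PEAK_MATMUL_TFLOPS = 312.0  # FP16/BF16 matmul peak from the hardware section


def fraction_of_peak(achieved_tflops, peak_tflops=A100_PEAK_MATMUL_TFLOPS):
    """Achieved throughput as a fraction of theoretical matmul peak."""
    return achieved_tflops / peak_tflops


forward_fraction = fraction_of_peak(230.0)
# Derived, not reported: throughput implied by 63% of peak in the backward pass.
implied_backward_tflops = 0.63 * A100_PEAK_MATMUL_TFLOPS

print(f"forward: {forward_fraction:.3f} of peak")
print(f"backward at 63% of peak: about {implied_backward_tflops:.0f} TFLOPs/s")
```

The forward fraction comes out close to the reported 73%; the small difference is consistent with "up to 230" being a rounded figure.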

The H100 result is also framed carefully. Running the same implementation without new H100 instructions reaches up to 335 TFLOPs/s, and the paper leaves another expected 1.5-2x speedup from TMA, fourth-generation Tensor Cores, and FP8 as future work. That limit is important: the reported H100 number is not a fully H100-specialized kernel.

For end-to-end training on 8xA100 80GB SXM, Table 1 reports training throughput in TFLOPs/s per GPU without FlashAttention, with FlashAttention, and with FlashAttention-2. GPT3-1.3B improves from 142 to 189 to 196 at 2k context and from 72 to 170 to 220 at 8k context. GPT3-2.7B improves from 149 to 189 to 205 at 2k context and from 80 to 175 to 225 at 8k context. The top result, 225 TFLOPs/s per A100 GPU, corresponds to 72% model FLOPs utilization.
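
The relative gains behind these Table 1 figures are easy to recompute. The sketch below uses only the throughput numbers quoted above plus the 312 TFLOPs/s A100 peak; the dictionary layout and function name are illustrative.

```python
A100_PEAK_MATMUL_TFLOPS = 312.0

# (model, context) -> TFLOPs/s/GPU: (no FlashAttention, FlashAttention, FlashAttention-2)
TABLE_1 = {
    ("GPT3-1.3B", "2k"): (142, 189, 196),
    ("GPT3-1.3B", "8k"): (72, 170, 220),
    ("GPT3-2.7B", "2k"): (149, 189, 205),
    ("GPT3-2.7B", "8k"): (80, 175, 225),
}


def speedups(row):
    """Return (FA2 vs no FlashAttention, FA2 vs FlashAttention) ratios."""
    baseline, fa1, fa2 = row
    return fa2 / baseline, fa2 / fa1


for (model, ctx), row in TABLE_1.items():
    vs_base, vs_fa1 = speedups(row)
    mfu = row[2] / A100_PEAK_MATMUL_TFLOPS
    print(f"{model} @ {ctx}: {vs_base:.2f}x vs baseline, "
          f"{vs_fa1:.2f}x vs FlashAttention, {mfu:.0%} of peak")
```

The gains over FlashAttention are concentrated at 8k context, where attention takes a larger share of total training time.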

Historical Effect

FlashAttention-2 marks a shift from algorithmic attention complexity alone to kernel-level accelerator fit. FlashAttention-1 made exact long-context attention practical by respecting the GPU memory hierarchy; FlashAttention-2 showed that the next bottleneck was parallelism and non-matmul overhead. That helped make 8k-16k context lengths cheaper as a standard training primitive and put attention kernels on a path closer to GEMM-like efficiency.

Limits

The method is kernel- and device-specific. It depends on GPU memory hierarchy, Tensor Core behavior, shared-memory capacity, register pressure, and block-size tuning. The paper's H100 results do not yet exploit the newer H100 features it names. FlashAttention-2 reduces the attention bottleneck, but full training throughput still depends on the rest of the Transformer stack, distributed parallelism, optimizer state, and data loading.
