                                                              FlashAttention-2:
                                          Faster Attention with Better Parallelism and Work Partitioning
                                                                                              Tri Dao1,2
                                                                1 Department of Computer Science, Princeton University
                                                                 2 Department of Computer Science, Stanford University




arXiv:2307.08691v1 [cs.LG] 17 Jul 2023
                                                                                     trid@cs.stanford.edu

                                                                                            July 18, 2023

                                                                                               Abstract
                                                    Scaling Transformers to longer sequence lengths has been a major problem in the last several years,
                                                promising to improve performance in language modeling and high-resolution image understanding, as
                                                well as to unlock new applications in code, audio, and video generation. The attention layer is the
                                                main bottleneck in scaling to longer sequences, as its runtime and memory increase quadratically in
                                                the sequence length. FlashAttention [5] exploits the asymmetric GPU memory hierarchy to bring
                                                significant memory saving (linear instead of quadratic) and runtime speedup (2-4× compared to optimized
                                                baselines), with no approximation. However, FlashAttention is still not nearly as fast as optimized
                                                matrix-multiply (GEMM) operations, reaching only 25-40% of the theoretical maximum FLOPs/s. We
                                                observe that the inefficiency is due to suboptimal work partitioning between different thread blocks and
                                                warps on the GPU, causing either low-occupancy or unnecessary shared memory reads/writes. We propose
                                                FlashAttention-2, with better work partitioning to address these issues. In particular, we (1) tweak
                                                the algorithm to reduce the number of non-matmul FLOPs (2) parallelize the attention computation, even
                                                for a single head, across different thread blocks to increase occupancy, and (3) within each thread block,
                                                distribute the work between warps to reduce communication through shared memory. These yield around
                                                2× speedup compared to FlashAttention, reaching 50-73% of the theoretical maximum FLOPs/s on
                                                A100 and getting close to the efficiency of GEMM operations. We empirically validate that when used
                                                end-to-end to train GPT-style models, FlashAttention-2 reaches training speed of up to 225 TFLOPs/s
                                                per A100 GPU (72% model FLOPs utilization).1


                                         1      Introduction
                                         Scaling up the context length of Transformers [18] is a challenge, since the attention layer at their heart
                                         has runtime and memory requirements quadratic in the input sequence length. Ideally, we would like to go
                                         beyond the standard 2k sequence length limit to train models to understand books, high resolution images,
                                         and long-form videos. Just within the last year, there have been several language models with much longer
                                         context than before: GPT-4 [12] with context length 32k, MosaicML’s MPT with context length 65k, and
                                         Anthropic’s Claude with context length 100k. Emerging use cases such as long document querying and story
                                         writing have demonstrated a need for models with such long context.
                                             To reduce the computational requirement of attention on such long context, there have been numerous
                                         methods proposed to approximate attention [2, 3, 4, 8, 9, 14, 19, 20]. Though these methods have seen
                                         some use cases, as far as we know, most large-scale training runs still use standard attention. Motivated by
                                         this, Dao et al. [5] proposed to reorder the attention computation and leverage classical techniques (tiling,
                                         recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence
                                         length. This yields 2-4× wall-clock time speedup over optimized baselines, up to 10-20× memory saving,
with no approximation, and as a result FlashAttention has seen wide adoption in large-scale training and
inference of Transformers.

    1 FlashAttention-2 is available at https://github.com/Dao-AILab/flash-attention

     However, as context lengths increase even more, FlashAttention is still not nearly as efficient as other
primitives such as matrix-multiply (GEMM). In particular, while FlashAttention is already 2-4× faster
than a standard attention implementation, the forward pass only reaches 30-50% of the theoretical maximum
FLOPs/s of the device (Fig. 5), while the backward pass is even more challenging, reaching only 25-35%
of maximum throughput on A100 GPU (Fig. 6). In contrast, optimized GEMM can reach up to 80-90% of
the theoretical maximum device throughput. Through careful profiling, we observe that FlashAttention
still has suboptimal work partitioning between different thread blocks and warps on the GPU, causing either
low-occupancy or unnecessary shared memory reads/writes.
     Building on FlashAttention, we propose FlashAttention-2 with better parallelism and work
partitioning to address these challenges.

    1. In Section 3.1, we tweak the algorithms to reduce the number of non-matmul FLOPs while not changing
       the output. While the non-matmul FLOPs only account for a small fraction of the total FLOPs, they
       take longer to perform as GPUs have specialized units for matrix multiply, and as a result the matmul
       throughput can be up to 16× higher than non-matmul throughput. It is thus important to reduce
       non-matmul FLOPs and spend as much time as possible doing matmul FLOPs.
    2. We propose to parallelize both the forward pass and backward pass along the sequence length dimension,
       in addition to the batch and number of heads dimension. This increases occupancy (utilization of GPU
       resources) in the case where the sequences are long (and hence batch size is often small).
    3. Even within one block of attention computation, we partition the work between different warps of a
       thread block to reduce communication and shared memory reads/writes.

    In Section 4, we empirically validate that FlashAttention-2 yields significant speedup compared to
even FlashAttention. Benchmarks on different settings (with or without causal mask, different head
dimensions) show that FlashAttention-2 achieves around 2× speedup over FlashAttention, reaching
up to 73% of the theoretical max throughput in the forward pass, and up to 63% of the theoretical max
throughput in the backward pass. When used end-to-end to train GPT-style models, we reach training speed
of up to 225 TFLOPs/s per A100 GPU.


2     Background
We provide some background on the performance characteristics and execution model of GPUs. We also
describe the standard implementation of attention, as well as FlashAttention.

2.1     Hardware characteristics
GPU performance characteristics. The GPU consists of compute elements (e.g., floating point arithmetic
units) and a memory hierarchy. Most modern GPUs contain specialized units to accelerate matrix multiply in
low-precision (e.g., Tensor Cores on Nvidia GPUs for FP16/BF16 matrix multiply). The memory hierarchy
comprises high bandwidth memory (HBM) and on-chip SRAM (aka shared memory). As an example, the
A100 GPU has 40-80GB of high bandwidth memory (HBM) with bandwidth 1.5-2.0TB/s and 192KB of
on-chip SRAM per each of 108 streaming multiprocessors with bandwidth estimated around 19TB/s [6, 7].
As the L2 cache is not directly controllable by the programmer, we focus on the HBM and SRAM for the
purpose of this discussion.
    Execution Model. GPUs have a massive number of threads to execute an operation (called a kernel).
Threads are organized into thread blocks, which are scheduled to run on streaming multiprocessors (SMs).
Within each thread block, threads are grouped into warps (a group of 32 threads). Threads within a warp
can communicate by fast shuffle instructions or cooperate to perform matrix multiply. Warps within a thread
block can communicate by reading from / writing to shared memory. Each kernel loads inputs from HBM to
registers and SRAM, computes, then writes outputs to HBM.



2.2     Standard Attention Implementation
Given input sequences Q, K, V ∈ R 𝑁 ×𝑑 where 𝑁 is the sequence length and 𝑑 is the head dimension, we want
to compute the attention output O ∈ R 𝑁 ×𝑑 :

                         S = QK⊤ ∈ R 𝑁 × 𝑁 ,     P = softmax(S) ∈ R 𝑁 × 𝑁 ,     O = PV ∈ R 𝑁 ×𝑑 ,

where softmax is applied row-wise.2 For multi-head attention (MHA), this same computation is performed in
parallel across many heads, and parallel over the batch dimension (number of input sequences in a batch).
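The three steps above can be sketched for a single head in plain Python. This is only an illustrative reference, not an efficient implementation: matrices are lists of rows, the helper names (`matmul`, `softmax_rows`, `standard_attention`) are our own, and the full N × N matrices S and P are materialized exactly as the definition suggests.

```python
import math

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def softmax_rows(s):
    """Row-wise softmax, subtracting each row's max for numerical stability."""
    out = []
    for row in s:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        total = sum(e)
        out.append([x / total for x in e])
    return out

def standard_attention(q, k, v):
    """O = softmax(Q K^T) V, materializing the full N x N matrices S and P."""
    s = matmul(q, transpose(k))   # S: N x N
    p = softmax_rows(s)           # P: N x N
    return matmul(p, v)           # O: N x d

# Tiny example with N = 3 and d = 2 (values chosen arbitrarily).
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
o = standard_attention(q, k, v)
```

Each output row is a convex combination of the rows of V, weighted by the corresponding row of P.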
   The backward pass of attention proceeds as follows. Let dO ∈ R 𝑁 ×𝑑 be the gradient of O with respect to
some loss function. Then by the chain rule (aka backpropagation):

                                               dV = P⊤ dO ∈ R 𝑁 ×𝑑
                                               dP = dOV⊤ ∈ R 𝑁 × 𝑁
                                               dS = dsoftmax(dP) ∈ R 𝑁 × 𝑁
                                               dQ = dSK ∈ R 𝑁 ×𝑑
                                               dK = QdS⊤ ∈ R 𝑁 ×𝑑 ,

where dsoftmax is the gradient (backward pass) of softmax applied row-wise. One can work out that if 𝑝 =
softmax(𝑠) for some vector 𝑠 and 𝑝, then with output gradient 𝑑𝑝, the input gradient 𝑑𝑠 = (diag( 𝑝) − 𝑝 𝑝 ⊤ )𝑑𝑝.
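As a small check of the softmax gradient formula, the sketch below applies the explicit Jacobian diag(p) − pp⊤ to a vector dp and compares it with the algebraically equivalent form p ∘ (dp − ⟨p, dp⟩), which avoids building the n × n Jacobian. The function names and input values are ours and purely illustrative.

```python
import math

def softmax(s):
    m = max(s)
    e = [math.exp(x - m) for x in s]
    total = sum(e)
    return [x / total for x in e]

def dsoftmax_jacobian(p, dp):
    """ds = (diag(p) - p p^T) dp, forming the full Jacobian explicitly."""
    n = len(p)
    jac = [[(p[i] if i == j else 0.0) - p[i] * p[j] for j in range(n)] for i in range(n)]
    return [sum(jac[i][j] * dp[j] for j in range(n)) for i in range(n)]

def dsoftmax_fused(p, dp):
    """Same result without the n x n Jacobian: ds = p * (dp - <p, dp>)."""
    dot = sum(pi * dpi for pi, dpi in zip(p, dp))
    return [pi * (dpi - dot) for pi, dpi in zip(p, dp)]

p = softmax([0.5, -1.0, 2.0])
dp = [0.1, -0.2, 0.3]
ds_jac = dsoftmax_jacobian(p, dp)
ds_fused = dsoftmax_fused(p, dp)
```

Both forms agree, and the entries of ds sum to zero because the softmax output always sums to one.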
    Standard attention implementations materialize the matrices S and P to HBM, which takes 𝑂(𝑁²)
memory. Often 𝑁 ≫ 𝑑 (typically 𝑁 is on the order of 1k–8k and 𝑑 is around 64–128). The standard attention
implementation (1) calls the matrix multiply (GEMM) subroutine to multiply S = QK⊤, writes the result to
HBM, then (2) loads S from HBM to compute softmax and writes the result P to HBM, and finally (3) calls
GEMM to get O = PV. As most of the operations are bounded by memory bandwidth, the large number of
memory accesses translates to slow wall-clock time. Moreover, the required memory is 𝑂(𝑁²) due to having
to materialize S and P, and one has to save P ∈ R^{𝑁×𝑁} for the backward pass to compute the gradients.
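To give a sense of scale, the following back-of-the-envelope calculation compares the size of one N × N matrix with the size of the inputs for a single head. It assumes FP16 storage (2 bytes per element) and picks N = 8192, d = 128 from the ranges quoted above; the function names are ours.

```python
def attention_matrix_bytes(seq_len, bytes_per_element=2):
    """Bytes needed to store one N x N matrix (S or P) for a single head."""
    return seq_len * seq_len * bytes_per_element

def qkv_bytes(seq_len, head_dim, bytes_per_element=2):
    """Bytes needed to store Q, K and V (each N x d) for a single head."""
    return 3 * seq_len * head_dim * bytes_per_element

n, d = 8192, 128
s_mib = attention_matrix_bytes(n) / 2**20   # 128.0 MiB per head
qkv_mib = qkv_bytes(n, d) / 2**20           # 6.0 MiB per head
```

Under these assumptions a single intermediate matrix is more than 20 times larger than Q, K, and V combined, and this cost is paid per head and per sequence in the batch.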

2.3     FlashAttention
To speed up attention on hardware accelerators such as GPU, [5] proposes an algorithm to reduce the memory
reads/writes while maintaining the same output (without approximation).

2.3.1    Forward pass
FlashAttention applies the classical technique of tiling to reduce memory IOs, by (1) loading blocks of
inputs from HBM to SRAM, (2) computing attention with respect to that block, and then (3) updating the
output without writing the large intermediate matrices S and P to HBM. As the softmax couples entire rows
or blocks of rows, online softmax [11, 13] can split the attention computation into blocks, and rescale the
output of each block to finally get the right result (with no approximation). By significantly reducing the
amount of memory reads/writes, FlashAttention yields 2-4× wall-clock speedup over optimized baseline
attention implementations.
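    We describe the online softmax technique [11] and how it is used in attention [13]. For simplicity, consider
just one row block of the attention matrix S, of the form $\begin{bmatrix} \mathbf{S}^{(1)} & \mathbf{S}^{(2)} \end{bmatrix}$ for some matrices
$\mathbf{S}^{(1)}, \mathbf{S}^{(2)} \in \mathbb{R}^{B_r \times B_c}$, where $B_r$ and $B_c$ are the row and column block sizes. We want to compute
softmax of this row block and multiply with the value, of the form $\begin{bmatrix} \mathbf{V}^{(1)} \\ \mathbf{V}^{(2)} \end{bmatrix}$ for some matrices
$\mathbf{V}^{(1)}, \mathbf{V}^{(2)} \in \mathbb{R}^{B_c \times d}$. Standard softmax would compute:

\[
\begin{aligned}
m &= \max(\mathrm{rowmax}(\mathbf{S}^{(1)}), \mathrm{rowmax}(\mathbf{S}^{(2)})) \in \mathbb{R}^{B_r} \\
\ell &= \mathrm{rowsum}(e^{\mathbf{S}^{(1)} - m}) + \mathrm{rowsum}(e^{\mathbf{S}^{(2)} - m}) \in \mathbb{R}^{B_r} \\
\mathbf{P} &= \begin{bmatrix} \mathbf{P}^{(1)} & \mathbf{P}^{(2)} \end{bmatrix}
  = \mathrm{diag}(\ell)^{-1} \begin{bmatrix} e^{\mathbf{S}^{(1)} - m} & e^{\mathbf{S}^{(2)} - m} \end{bmatrix} \in \mathbb{R}^{B_r \times 2B_c} \\
\mathbf{O} &= \begin{bmatrix} \mathbf{P}^{(1)} & \mathbf{P}^{(2)} \end{bmatrix} \begin{bmatrix} \mathbf{V}^{(1)} \\ \mathbf{V}^{(2)} \end{bmatrix}
  = \mathrm{diag}(\ell)^{-1} \left( e^{\mathbf{S}^{(1)} - m} \mathbf{V}^{(1)} + e^{\mathbf{S}^{(2)} - m} \mathbf{V}^{(2)} \right) \in \mathbb{R}^{B_r \times d}.
\end{aligned}
\]

   2 For clarity of exposition, we omit the scaling of QK⊤ (typically by 1/√d), and optionally elementwise masking on S and/or
dropout applied to P.
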

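Online softmax instead computes “local” softmax with respect to each block and rescales to get the right
output at the end:

\[
\begin{aligned}
m^{(1)} &= \mathrm{rowmax}(\mathbf{S}^{(1)}) \in \mathbb{R}^{B_r} \\
\ell^{(1)} &= \mathrm{rowsum}(e^{\mathbf{S}^{(1)} - m^{(1)}}) \in \mathbb{R}^{B_r} \\
\tilde{\mathbf{P}}^{(1)} &= \mathrm{diag}(\ell^{(1)})^{-1} e^{\mathbf{S}^{(1)} - m^{(1)}} \in \mathbb{R}^{B_r \times B_c} \\
\mathbf{O}^{(1)} &= \tilde{\mathbf{P}}^{(1)} \mathbf{V}^{(1)} = \mathrm{diag}(\ell^{(1)})^{-1} e^{\mathbf{S}^{(1)} - m^{(1)}} \mathbf{V}^{(1)} \in \mathbb{R}^{B_r \times d} \\
m^{(2)} &= \max(m^{(1)}, \mathrm{rowmax}(\mathbf{S}^{(2)})) = m \\
\ell^{(2)} &= e^{m^{(1)} - m^{(2)}} \ell^{(1)} + \mathrm{rowsum}(e^{\mathbf{S}^{(2)} - m^{(2)}})
  = \mathrm{rowsum}(e^{\mathbf{S}^{(1)} - m}) + \mathrm{rowsum}(e^{\mathbf{S}^{(2)} - m}) = \ell \\
\tilde{\mathbf{P}}^{(2)} &= \mathrm{diag}(\ell^{(2)})^{-1} e^{\mathbf{S}^{(2)} - m^{(2)}} \\
\mathbf{O}^{(2)} &= \mathrm{diag}(\ell^{(2)})^{-1} \mathrm{diag}(\ell^{(1)}) \, \mathrm{diag}(e^{m^{(1)} - m^{(2)}}) \, \mathbf{O}^{(1)} + \tilde{\mathbf{P}}^{(2)} \mathbf{V}^{(2)} \\
  &= \mathrm{diag}(\ell^{(2)})^{-1} e^{\mathbf{S}^{(1)} - m} \mathbf{V}^{(1)} + \mathrm{diag}(\ell^{(2)})^{-1} e^{\mathbf{S}^{(2)} - m} \mathbf{V}^{(2)} = \mathbf{O}.
\end{aligned}
\]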
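The two-block computation can be checked numerically. The sketch below is an illustration in plain Python with helper names of our own choosing: it processes S^(1) and S^(2) one after the other, rescales the first block's output by ℓ^(1) e^{m^(1) − m^(2)} / ℓ^(2) so that both blocks share the final max and normalizer, and compares the result with a standard softmax over the concatenated row block.

```python
import math

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def exp_shifted(s, m):
    """Elementwise exp(S - m), where m holds one shift per row."""
    return [[math.exp(x - mi) for x in row] for row, mi in zip(s, m)]

def online_softmax_two_blocks(s1, s2, v1, v2):
    # Block 1: local max, local normalizer, locally normalized output.
    m1 = [max(row) for row in s1]
    e1 = exp_shifted(s1, m1)
    l1 = [sum(row) for row in e1]
    o1 = [[x / li for x in row] for row, li in zip(matmul(e1, v1), l1)]
    # Block 2: update the running max and normalizer.
    m2 = [max(a, max(row)) for a, row in zip(m1, s2)]
    e2 = exp_shifted(s2, m2)
    l2 = [math.exp(a - b) * li + sum(row) for a, b, li, row in zip(m1, m2, l1, e2)]
    # Rescale the old output by l1 * exp(m1 - m2) / l2, then add the new block.
    pv2 = matmul(e2, v2)
    o2 = []
    for i in range(len(o1)):
        scale = l1[i] * math.exp(m1[i] - m2[i]) / l2[i]
        o2.append([scale * a + b / l2[i] for a, b in zip(o1[i], pv2[i])])
    return o2

def reference(s1, s2, v1, v2):
    """Standard softmax over the concatenated row block [S1 S2]."""
    s = [r1 + r2 for r1, r2 in zip(s1, s2)]
    v = v1 + v2
    m = [max(row) for row in s]
    e = exp_shifted(s, m)
    return [[x / sum(erow) for x in orow] for orow, erow in zip(matmul(e, v), e)]

# Arbitrary example with B_r = 2, B_c = 2, d = 2.
s1 = [[0.2, 1.5], [3.0, -1.0]]
s2 = [[2.5, 0.1], [0.0, 0.4]]
v1 = [[1.0, 0.0], [0.0, 1.0]]
v2 = [[2.0, 2.0], [-1.0, 3.0]]
out_online = online_softmax_two_blocks(s1, s2, v1, v2)
out_ref = reference(s1, s2, v1, v2)
```

The two outputs agree up to floating-point rounding, even though the online version never sees both blocks at once.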

   We show how FlashAttention uses online softmax to enable tiling (Fig. 1) to reduce memory reads/writes.




Figure 1: Diagram of how FlashAttention forward pass is performed, when the key K is partitioned into
two blocks and the value V is also partitioned into two blocks. By computing attention with respect to
each block and rescaling the output, we get the right answer at the end, while avoiding expensive memory
reads/writes of the intermediate matrices S and P. We simplify the diagram, omitting the step in softmax
that subtracts each element by the row-wise max.




2.3.2     Backward pass
In the backward pass, by re-computing the values of the attention matrices S and P once blocks of inputs
Q, K, V are already loaded to SRAM, FlashAttention avoids having to store large intermediate values. By
not having to save the large matrices S and P of size 𝑁 × 𝑁, FlashAttention yields 10-20× memory saving
depending on sequence length (memory required is linear in sequence length 𝑁 instead of quadratic). The
backward pass also achieves 2-4× wall-clock speedup due to reduced memory reads/writes.
    The backward pass applies tiling to the equations in Section 2.2. Though the backward pass is simpler
than the forward pass conceptually (there is no softmax rescaling), the implementation is significantly more
involved. This is because there are more values to be kept in SRAM to perform 5 matrix multiplies in the
backward pass, compared to just 2 matrix multiplies in the forward pass.


3       FlashAttention-2: Algorithm, Parallelism, and Work Partitioning
We describe the FlashAttention-2 algorithm, which includes several tweaks to FlashAttention to reduce
the number of non-matmul FLOPs. We then describe how to parallelize the computation on different thread
blocks to make full use of the GPU resources. Finally, we describe how we partition the work between different warps
within one thread block to reduce the amount of shared memory access. These improvements lead to 2-3×
speedup as validated in Section 4.

3.1      Algorithm
We tweak the algorithm from FlashAttention to reduce the number of non-matmul FLOPs. This is
because modern GPUs have specialized compute units (e.g., Tensor Cores on Nvidia GPUs) that make
matmul much faster. As an example, the A100 GPU has a max theoretical throughput of 312 TFLOPs/s of
FP16/BF16 matmul, but only 19.5 TFLOPs/s of non-matmul FP32. Another way to think about this is that
each non-matmul FLOP is 16× more expensive than a matmul FLOP. To maintain high throughput (e.g.,
more than 50% of the maximum theoretical TFLOPs/s), we want to spend as much time on matmul FLOPs
as possible.
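A rough calculation shows why even a small share of non-matmul FLOPs matters. The sketch below uses the two A100 peak figures quoted above and assumes, purely for illustration, that matmul and non-matmul work each run at peak throughput and do not overlap; the 5% non-matmul share is a hypothetical value, not a measurement.

```python
MATMUL_TFLOPS = 312.0      # A100 FP16/BF16 matmul peak, from the text
NON_MATMUL_TFLOPS = 19.5   # A100 FP32 non-matmul peak, from the text

def non_matmul_time_fraction(non_matmul_flop_fraction):
    """Share of runtime spent on non-matmul work under the no-overlap assumption."""
    matmul_time = (1.0 - non_matmul_flop_fraction) / MATMUL_TFLOPS
    other_time = non_matmul_flop_fraction / NON_MATMUL_TFLOPS
    return other_time / (matmul_time + other_time)

cost_ratio = MATMUL_TFLOPS / NON_MATMUL_TFLOPS   # 16.0
share = non_matmul_time_fraction(0.05)           # hypothetical 5% of FLOPs
```

Under these assumptions, 5% of the FLOPs would account for roughly 46% of the runtime, which is why trimming non-matmul work pays off disproportionately.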

3.1.1     Forward pass
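We revisit the online softmax trick as shown in Section 2.3 and make two minor tweaks to reduce non-matmul
FLOPs:
    1. We do not have to rescale both terms of the output update by $\mathrm{diag}(\ell^{(2)})^{-1}$:

       \[
       \mathbf{O}^{(2)} = \mathrm{diag}(\ell^{(2)})^{-1} \mathrm{diag}(\ell^{(1)}) \, \mathrm{diag}(e^{m^{(1)} - m^{(2)}}) \, \mathbf{O}^{(1)} + \mathrm{diag}(\ell^{(2)})^{-1} e^{\mathbf{S}^{(2)} - m^{(2)}} \mathbf{V}^{(2)}.
       \]

       We can instead maintain an “un-scaled” version of $\mathbf{O}^{(2)}$ and keep around the statistics $\ell^{(2)}$:

       \[
       \tilde{\mathbf{O}}^{(2)} = \mathrm{diag}(e^{m^{(1)} - m^{(2)}}) \, \tilde{\mathbf{O}}^{(1)} + e^{\mathbf{S}^{(2)} - m^{(2)}} \mathbf{V}^{(2)},
       \qquad \tilde{\mathbf{O}}^{(1)} = \mathrm{diag}(\ell^{(1)}) \, \mathbf{O}^{(1)} = e^{\mathbf{S}^{(1)} - m^{(1)}} \mathbf{V}^{(1)}.
       \]

       Only at the very end of the loop do we scale the final $\tilde{\mathbf{O}}^{(\mathrm{last})}$ by $\mathrm{diag}(\ell^{(\mathrm{last})})^{-1}$ to get the right output.
    2. We do not have to save both the max $m^{(j)}$ and the sum of exponentials $\ell^{(j)}$ for the backward pass. We
       only need to store the logsumexp $L^{(j)} = m^{(j)} + \log(\ell^{(j)})$.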




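      In the simple case of 2 blocks in Section 2.3, the online softmax trick now becomes:

\[
\begin{aligned}
m^{(1)} &= \mathrm{rowmax}(\mathbf{S}^{(1)}) \in \mathbb{R}^{B_r} \\
\ell^{(1)} &= \mathrm{rowsum}(e^{\mathbf{S}^{(1)} - m^{(1)}}) \in \mathbb{R}^{B_r} \\
\tilde{\mathbf{O}}^{(1)} &= e^{\mathbf{S}^{(1)} - m^{(1)}} \mathbf{V}^{(1)} \in \mathbb{R}^{B_r \times d} \\
m^{(2)} &= \max(m^{(1)}, \mathrm{rowmax}(\mathbf{S}^{(2)})) = m \\
\ell^{(2)} &= e^{m^{(1)} - m^{(2)}} \ell^{(1)} + \mathrm{rowsum}(e^{\mathbf{S}^{(2)} - m^{(2)}})
  = \mathrm{rowsum}(e^{\mathbf{S}^{(1)} - m}) + \mathrm{rowsum}(e^{\mathbf{S}^{(2)} - m}) = \ell \\
\tilde{\mathbf{P}}^{(2)} &= \mathrm{diag}(\ell^{(2)})^{-1} e^{\mathbf{S}^{(2)} - m^{(2)}} \\
\tilde{\mathbf{O}}^{(2)} &= \mathrm{diag}(e^{m^{(1)} - m^{(2)}}) \, \tilde{\mathbf{O}}^{(1)} + e^{\mathbf{S}^{(2)} - m^{(2)}} \mathbf{V}^{(2)}
  = e^{\mathbf{S}^{(1)} - m} \mathbf{V}^{(1)} + e^{\mathbf{S}^{(2)} - m} \mathbf{V}^{(2)} \\
\mathbf{O}^{(2)} &= \mathrm{diag}(\ell^{(2)})^{-1} \tilde{\mathbf{O}}^{(2)} = \mathbf{O}.
\end{aligned}
\]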

      We describe the full FlashAttention-2 forward pass in Algorithm 1.

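Algorithm 1 FlashAttention-2 forward pass
Require: Matrices $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}$ in HBM, block sizes $B_c$, $B_r$.
 1: Divide $\mathbf{Q}$ into $T_r = \lceil N / B_r \rceil$ blocks $\mathbf{Q}_1, \dots, \mathbf{Q}_{T_r}$ of size $B_r \times d$ each, and divide $\mathbf{K}, \mathbf{V}$ into $T_c = \lceil N / B_c \rceil$ blocks
    $\mathbf{K}_1, \dots, \mathbf{K}_{T_c}$ and $\mathbf{V}_1, \dots, \mathbf{V}_{T_c}$, of size $B_c \times d$ each.
 2: Divide the output $\mathbf{O} \in \mathbb{R}^{N \times d}$ into $T_r$ blocks $\mathbf{O}_1, \dots, \mathbf{O}_{T_r}$ of size $B_r \times d$ each, and divide the logsumexp $L$
    into $T_r$ blocks $L_1, \dots, L_{T_r}$ of size $B_r$ each.
 3: for $1 \le i \le T_r$ do
 4:    Load $\mathbf{Q}_i$ from HBM to on-chip SRAM.
 5:    On chip, initialize $\mathbf{O}_i^{(0)} = (0)_{B_r \times d} \in \mathbb{R}^{B_r \times d}$, $\ell_i^{(0)} = (0)_{B_r} \in \mathbb{R}^{B_r}$, $m_i^{(0)} = (-\infty)_{B_r} \in \mathbb{R}^{B_r}$.
 6:    for $1 \le j \le T_c$ do
 7:       Load $\mathbf{K}_j, \mathbf{V}_j$ from HBM to on-chip SRAM.
 8:       On chip, compute $\mathbf{S}_i^{(j)} = \mathbf{Q}_i \mathbf{K}_j^\top \in \mathbb{R}^{B_r \times B_c}$.
 9:       On chip, compute $m_i^{(j)} = \max(m_i^{(j-1)}, \mathrm{rowmax}(\mathbf{S}_i^{(j)})) \in \mathbb{R}^{B_r}$, $\tilde{\mathbf{P}}_i^{(j)} = \exp(\mathbf{S}_i^{(j)} - m_i^{(j)}) \in \mathbb{R}^{B_r \times B_c}$
          (pointwise), $\ell_i^{(j)} = e^{m_i^{(j-1)} - m_i^{(j)}} \ell_i^{(j-1)} + \mathrm{rowsum}(\tilde{\mathbf{P}}_i^{(j)}) \in \mathbb{R}^{B_r}$.
10:       On chip, compute $\mathbf{O}_i^{(j)} = \mathrm{diag}(e^{m_i^{(j-1)} - m_i^{(j)}}) \, \mathbf{O}_i^{(j-1)} + \tilde{\mathbf{P}}_i^{(j)} \mathbf{V}_j$.
11:    end for
12:    On chip, compute $\mathbf{O}_i = \mathrm{diag}(\ell_i^{(T_c)})^{-1} \mathbf{O}_i^{(T_c)}$.
13:    On chip, compute $L_i = m_i^{(T_c)} + \log(\ell_i^{(T_c)})$.
14:    Write $\mathbf{O}_i$ to HBM as the $i$-th block of $\mathbf{O}$.
15:    Write $L_i$ to HBM as the $i$-th block of $L$.
16: end for
17: Return the output $\mathbf{O}$ and the logsumexp $L$.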
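The loop structure of Algorithm 1 can be mirrored in a few lines of plain Python. This is a minimal sketch for checking the arithmetic only: it runs sequentially on lists of rows, has no notion of HBM or SRAM, omits scaling, masking and dropout, and the function names are ours. The row blocks need not divide N evenly.

```python
import math

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def flash_attention2_forward(q, k, v, block_r, block_c):
    """Tiled forward pass in the spirit of Algorithm 1; returns (O, L)."""
    n, d = len(q), len(q[0])
    out, lse = [], []
    for r0 in range(0, n, block_r):                 # outer loop over row blocks of Q
        q_i = q[r0:r0 + block_r]
        rows = len(q_i)
        o_i = [[0.0] * d for _ in range(rows)]      # un-scaled output accumulator
        l_i = [0.0] * rows                          # running sum of exponentials
        m_i = [-math.inf] * rows                    # running row max
        for c0 in range(0, n, block_c):             # inner loop over K, V blocks
            k_j, v_j = k[c0:c0 + block_c], v[c0:c0 + block_c]
            s = matmul(q_i, [list(col) for col in zip(*k_j)])   # S = Q_i K_j^T
            for a in range(rows):
                m_new = max(m_i[a], max(s[a]))
                p = [math.exp(x - m_new) for x in s[a]]
                alpha = math.exp(m_i[a] - m_new)    # rescales the old accumulators
                l_i[a] = alpha * l_i[a] + sum(p)
                pv = [sum(pk * v_j[t][c] for t, pk in enumerate(p)) for c in range(d)]
                o_i[a] = [alpha * o + x for o, x in zip(o_i[a], pv)]
                m_i[a] = m_new
        for a in range(rows):                       # scale once, at the very end
            out.append([o / l_i[a] for o in o_i[a]])
            lse.append(m_i[a] + math.log(l_i[a]))
    return out, lse

def reference_attention(q, k, v):
    """Row-by-row softmax(Q K^T) V with the logsumexp of each row."""
    out, lse = [], []
    for qrow in q:
        s = [sum(x * y for x, y in zip(qrow, krow)) for krow in k]
        m = max(s)
        e = [math.exp(x - m) for x in s]
        tot = sum(e)
        out.append([sum(e[t] * v[t][c] for t in range(len(v))) / tot for c in range(len(v[0]))])
        lse.append(m + math.log(tot))
    return out, lse

# Arbitrary deterministic inputs with N = 5, d = 2 and uneven block sizes.
n, d = 5, 2
q = [[math.sin(1 + i + 2 * c) for c in range(d)] for i in range(n)]
k = [[math.cos(2 + 3 * i + c) for c in range(d)] for i in range(n)]
v = [[float(i + c) for c in range(d)] for i in range(n)]
out, lse = flash_attention2_forward(q, k, v, block_r=2, block_c=3)
ref_out, ref_lse = reference_attention(q, k, v)
```

The tiled result matches the reference for both the output and the logsumexp, while only ever holding one B_r × B_c tile of scores at a time.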



Causal masking. One common use case of attention is in auto-regressive language modeling, where we
need to apply a causal mask to the attention matrix S (i.e., any entry S𝑖 𝑗 with 𝑗 > 𝑖 is set to −∞).
      1. As FlashAttention and FlashAttention-2 already operate by blocks, for any blocks where all
         the column indices are more than the row indices (approximately half of the blocks for large sequence
         length), we can skip the computation of that block. This leads to around 1.7-1.8× speedup compared
         to attention without the causal mask.
      2. We do not need to apply the causal mask for blocks whose row indices are guaranteed to be strictly greater
         than the column indices. This means that for each row, we only need to apply the causal mask to 1 block
         (assuming square blocks).
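The two observations amount to sorting tiles into three kinds. The sketch below is our own illustration (the labels and function name are not from the paper): a tile is skipped when every column index exceeds every row index, needs no mask when no column index exceeds any row index, and otherwise straddles the diagonal and needs the elementwise mask.

```python
def classify_causal_blocks(n, block_r, block_c):
    """Label each (row block, column block) tile as 'skip', 'full', or 'mask'."""
    labels = {}
    for i, r0 in enumerate(range(0, n, block_r)):
        r1 = min(r0 + block_r, n) - 1          # last row index in the tile
        for j, c0 in enumerate(range(0, n, block_c)):
            c1 = min(c0 + block_c, n) - 1      # last column index in the tile
            if c0 > r1:
                labels[(i, j)] = "skip"        # every column index exceeds every row index
            elif c1 <= r0:
                labels[(i, j)] = "full"        # no entry has column > row
            else:
                labels[(i, j)] = "mask"        # straddles the diagonal
    return labels

labels = classify_causal_blocks(n=512, block_r=128, block_c=128)
counts = {kind: sum(1 for v in labels.values() if v == kind) for kind in ("skip", "full", "mask")}
# counts == {"skip": 6, "full": 6, "mask": 4}
```

With square 128 × 128 tiles on a length-512 sequence, 6 of the 16 tiles are skipped and each row block has exactly one tile that needs the mask. The skipped share approaches one half as the number of blocks grows.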




Correctness, runtime, and memory requirement. As with FlashAttention, Algorithm 1 returns
the correct output O = softmax(QK⊤)V (with no approximation), using 𝑂(𝑁²𝑑) FLOPs and requiring 𝑂(𝑁)
additional memory beyond inputs and output (to store the logsumexp 𝐿). The proof is almost the same as
the proof of Dao et al. [5, Theorem 1], so we omit it here.

3.1.2   Backward pass
The backward pass of FlashAttention-2 is almost the same as that of FlashAttention. We make a
minor tweak to only use the row-wise logsumexp 𝐿 instead of both the row-wise max and row-wise sum of
exponentials in the softmax. We include the backward pass description in Algorithm 2 for completeness.

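Algorithm 2 FlashAttention-2 Backward Pass
Require: Matrices $\mathbf{Q}, \mathbf{K}, \mathbf{V}, \mathbf{O}, \mathbf{dO} \in \mathbb{R}^{N \times d}$ in HBM, vector $L \in \mathbb{R}^{N}$ in HBM, block sizes $B_c$, $B_r$.
 1: Divide $\mathbf{Q}$ into $T_r = \lceil N / B_r \rceil$ blocks $\mathbf{Q}_1, \dots, \mathbf{Q}_{T_r}$ of size $B_r \times d$ each, and divide $\mathbf{K}, \mathbf{V}$ into $T_c = \lceil N / B_c \rceil$ blocks
    $\mathbf{K}_1, \dots, \mathbf{K}_{T_c}$ and $\mathbf{V}_1, \dots, \mathbf{V}_{T_c}$, of size $B_c \times d$ each.
 2: Divide $\mathbf{O}$ into $T_r$ blocks $\mathbf{O}_1, \dots, \mathbf{O}_{T_r}$ of size $B_r \times d$ each, divide $\mathbf{dO}$ into $T_r$ blocks $\mathbf{dO}_1, \dots, \mathbf{dO}_{T_r}$ of size
    $B_r \times d$ each, and divide $L$ into $T_r$ blocks $L_1, \dots, L_{T_r}$ of size $B_r$ each.
 3: Initialize $\mathbf{dQ} = (0)_{N \times d}$ in HBM and divide it into $T_r$ blocks $\mathbf{dQ}_1, \dots, \mathbf{dQ}_{T_r}$ of size $B_r \times d$ each. Divide
    $\mathbf{dK}, \mathbf{dV} \in \mathbb{R}^{N \times d}$ into $T_c$ blocks $\mathbf{dK}_1, \dots, \mathbf{dK}_{T_c}$ and $\mathbf{dV}_1, \dots, \mathbf{dV}_{T_c}$, of size $B_c \times d$ each.
 4: Compute $D = \mathrm{rowsum}(\mathbf{dO} \circ \mathbf{O}) \in \mathbb{R}^{N}$ (pointwise multiply), write $D$ to HBM and divide it into $T_r$ blocks
    $D_1, \dots, D_{T_r}$ of size $B_r$ each.
 5: for $1 \le j \le T_c$ do
 6:    Load $\mathbf{K}_j, \mathbf{V}_j$ from HBM to on-chip SRAM.
 7:    Initialize $\mathbf{dK}_j = (0)_{B_c \times d}$, $\mathbf{dV}_j = (0)_{B_c \times d}$ on SRAM.
 8:    for $1 \le i \le T_r$ do
 9:       Load $\mathbf{Q}_i, \mathbf{O}_i, \mathbf{dO}_i, \mathbf{dQ}_i, L_i, D_i$ from HBM to on-chip SRAM.
10:       On chip, compute $\mathbf{S}_i^{(j)} = \mathbf{Q}_i \mathbf{K}_j^\top \in \mathbb{R}^{B_r \times B_c}$.
11:       On chip, compute $\mathbf{P}_i^{(j)} = \exp(\mathbf{S}_i^{(j)} - L_i) \in \mathbb{R}^{B_r \times B_c}$.
12:       On chip, compute $\mathbf{dV}_j \leftarrow \mathbf{dV}_j + (\mathbf{P}_i^{(j)})^\top \mathbf{dO}_i \in \mathbb{R}^{B_c \times d}$.
13:       On chip, compute $\mathbf{dP}_i^{(j)} = \mathbf{dO}_i \mathbf{V}_j^\top \in \mathbb{R}^{B_r \times B_c}$.
14:       On chip, compute $\mathbf{dS}_i^{(j)} = \mathbf{P}_i^{(j)} \circ (\mathbf{dP}_i^{(j)} - D_i) \in \mathbb{R}^{B_r \times B_c}$.
15:       Load $\mathbf{dQ}_i$ from HBM to SRAM, then on chip, update $\mathbf{dQ}_i \leftarrow \mathbf{dQ}_i + \mathbf{dS}_i^{(j)} \mathbf{K}_j \in \mathbb{R}^{B_r \times d}$, and write back
          to HBM.
16:       On chip, compute $\mathbf{dK}_j \leftarrow \mathbf{dK}_j + (\mathbf{dS}_i^{(j)})^\top \mathbf{Q}_i \in \mathbb{R}^{B_c \times d}$.
17:    end for
18:    Write $\mathbf{dK}_j, \mathbf{dV}_j$ to HBM.
19: end for
20: Return $\mathbf{dQ}, \mathbf{dK}, \mathbf{dV}$.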
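For readers wondering where the vector D in steps 4 and 14 comes from, the following short derivation (our own working, following from the softmax gradient formula in Section 2.2) shows that the row-wise correction term of dsoftmax can be computed from dO and O alone, without access to a full row of P. Here $p, dp, o, do$ denote the $i$-th rows of $\mathbf{P}, \mathbf{dP}, \mathbf{O}, \mathbf{dO}$, and $v_j$ denotes the $j$-th row of $\mathbf{V}$.

```latex
\begin{aligned}
ds &= (\mathrm{diag}(p) - p p^\top)\, dp = p \circ dp - (p^\top dp)\, p, \\
p^\top dp &= \sum_j p_j \, dp_j = \sum_j p_j \, (do^\top v_j) = do^\top \Big( \sum_j p_j v_j \Big) = do^\top o = D_i, \\
\Rightarrow \quad \mathbf{dS} &= \mathbf{P} \circ (\mathbf{dP} - D), \qquad D = \mathrm{rowsum}(\mathbf{dO} \circ \mathbf{O}) = \mathrm{rowsum}(\mathbf{P} \circ \mathbf{dP}),
\end{aligned}
```

where $D$ is subtracted from every column of $\mathbf{dP}$ (one scalar per row). This is what lets each column block compute its tile of dS independently once D has been precomputed.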




Multi-query attention and grouped-query attention. Multi-query attention (MQA) [15] and grouped-
query attention (GQA) [1] are variants of attention where multiple heads of query attend to the same head of
key and value, in order to reduce the size of KV cache during inference. Instead of having to duplicate the
key and value heads for the computation, we implicitly manipulate the indices into the head to perform the
same computation. In the backward pass, we need to sum the gradients dK and dV across different heads
that were implicitly duplicated.
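The index manipulation can be sketched as follows. The contiguous grouping used here (query heads 0..g−1 share key/value head 0, and so on) is an assumption made for illustration, as are the function names; the text does not spell out the layout. Scalars stand in for the per-head gradient tensors.

```python
def kv_head_for_query_head(q_head, num_q_heads, num_kv_heads):
    """Map a query head to the key/value head it shares (contiguous grouping assumed)."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size

def reduce_kv_grads(per_q_head_grads, num_kv_heads):
    """Sum per-query-head gradients into one gradient per key/value head."""
    num_q_heads = len(per_q_head_grads)
    reduced = [0.0] * num_kv_heads
    for h, g in enumerate(per_q_head_grads):
        reduced[kv_head_for_query_head(h, num_q_heads, num_kv_heads)] += g
    return reduced

# GQA with 8 query heads and 2 key/value heads; MQA is the case num_kv_heads = 1.
mapping = [kv_head_for_query_head(h, 8, 2) for h in range(8)]   # [0, 0, 0, 0, 1, 1, 1, 1]
grads = reduce_kv_grads([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], 2)   # [10.0, 26.0]
```

The forward pass only needs the mapping, so no key or value head is ever copied; the backward pass needs the reduction because several query heads contribute gradient to the same key/value head.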

3.2     Parallelism
The first version of FlashAttention parallelizes over batch size and number of heads. We use 1 thread
block to process one attention head, and there are overall batch size · number of heads thread blocks. Each
thread block is scheduled to run on a streaming multiprocessor (SM), and there are 108 of these SMs on
an A100 GPU for example. This scheduling is efficient when this number is large (say ≥ 80), since we can
effectively use almost all of the compute resources on the GPU.
    In the case of long sequences (which usually means small batch sizes or small number of heads), to make
better use of the multiprocessors on the GPU, we now additionally parallelize over the sequence length
dimension. This results in significant speedup for this regime.
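A quick count illustrates the regime in question. The sketch below compares the number of thread blocks launched with and without the extra split along the sequence; the batch size, head count, sequence length, and a row block size of 128 are hypothetical values chosen for illustration.

```python
import math

NUM_SMS = 108  # streaming multiprocessors on an A100, from the text

def thread_blocks(batch, heads, seq_len=None, block_r=None):
    """Thread blocks launched; also splits the sequence when seq_len and block_r are given."""
    blocks = batch * heads
    if seq_len is not None and block_r is not None:
        blocks *= math.ceil(seq_len / block_r)
    return blocks

old = thread_blocks(batch=1, heads=16)                              # 16
new = thread_blocks(batch=1, heads=16, seq_len=16384, block_r=128)  # 2048
```

With batch size 1 and 16 heads, parallelizing only over batch and heads yields 16 thread blocks for 108 SMs, whereas also splitting a 16k sequence into row blocks yields 2048, enough to keep every SM busy.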

Forward pass. We see that the outer loop (over sequence length) is embarrassingly parallel, and we
schedule its iterations on different thread blocks that do not need to communicate with each other. We also parallelize
over the batch dimension and number of heads dimension, as done in FlashAttention. The increased
parallelism over sequence length helps improve occupancy (fraction of GPU resources being used) when the
batch size and number of heads are small, leading to speedup in this case.
   These ideas of swapping the order of the loop (outer loop over row blocks and inner loop over column
blocks, instead of the other way round in the original FlashAttention paper), as well as parallelizing
over the sequence length dimension were first suggested and implemented by Phil Tillet in the Triton [17]
implementation.3

Backward pass. Notice that the only shared computation between different column blocks is the update of dQ
in Algorithm 2, where we need to load $\mathbf{dQ}_i$ from HBM to SRAM, then on chip, update $\mathbf{dQ}_i \leftarrow \mathbf{dQ}_i + \mathbf{dS}_i^{(j)} \mathbf{K}_j$,
and write back to HBM. We thus parallelize over the sequence length dimension as well, and schedule 1
thread block for each column block of the backward pass. We use atomic adds to communicate between
different thread blocks to update dQ.
    We describe the parallelization scheme in Fig. 2.




Figure 2: In the forward pass (left), we parallelize the workers (thread blocks) where each worker takes care
of a block of rows of the attention matrix. In the backward pass (right), each worker takes care of a block of
columns of the attention matrix.

  3 https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py




3.3    Work Partitioning Between Warps
Section 3.2 describes how we schedule thread blocks; even within each thread block, we also have to decide
how to partition the work between different warps. We typically use 4 or 8 warps per thread block, and the
partitioning is described in Fig. 3.

Forward pass. For each block, FlashAttention splits K and V across 4 warps while keeping Q accessible
by all warps. Each warp multiplies to get a slice of QK⊤ , then they need to multiply with a slice of V and
communicate to add up the result. This is referred to as the “split-K” scheme. However, this is inefficient
since all warps need to write their intermediate results out to shared memory, synchronize, then add up the
intermediate results. These shared memory reads/writes slow down the forward pass in FlashAttention.
    In FlashAttention-2, we instead split Q across 4 warps while keeping K and V accessible by all warps.
After each warp performs matrix multiply to get a slice of QK⊤ , they just need to multiply with their shared
slice of V to get their corresponding slice of the output. There is no need for communication between warps.
The reduction in shared memory reads/writes yields speedup (Section 4).
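The difference between the two schemes can be illustrated on the product PV alone. In this simplified sketch (our own, with Python loops standing in for warps and the softmax omitted), splitting along the key/value dimension makes every "warp" produce a full-size partial result that must then be summed, whereas splitting along the query rows gives each "warp" its own disjoint rows of the output. The slicing assumes the split dimension is divisible by the number of warps.

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def split_k(p, v, num_warps):
    """Each 'warp' owns a slice of the key/value dimension and yields a full-size partial sum."""
    step = len(v) // num_warps
    partials = []
    for w in range(num_warps):
        p_slice = [row[w * step:(w + 1) * step] for row in p]
        partials.append(matmul(p_slice, v[w * step:(w + 1) * step]))
    # Cross-warp reduction: on a GPU this is the shared-memory traffic the text describes.
    total = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
    return total, len(partials)

def split_q(p, v, num_warps):
    """Each 'warp' owns a slice of the query rows and yields its own output rows."""
    step = len(p) // num_warps
    out = []
    for w in range(num_warps):
        out.extend(matmul(p[w * step:(w + 1) * step], v))   # no reduction needed
    return out

# Small integer example: P is 4 x 4, V is 4 x 2, split across 4 'warps'.
p = [[1, 2, 3, 4], [0, 1, 0, 1], [2, 0, 2, 0], [1, 1, 1, 1]]
v = [[1, 0], [0, 1], [1, 1], [2, 3]]
expected = matmul(p, v)
reduced, num_partials = split_k(p, v, 4)
rows_only = split_q(p, v, 4)
```

Both schemes produce the same output, but only the first requires combining `num_partials` intermediate results across warps.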




                  (a) FlashAttention                                    (b) FlashAttention-2

                 Figure 3: Work partitioning between different warps in the forward pass


Backward pass. Similarly for the backward pass, we choose to partition the warps to avoid the “split-K”
scheme. However, it still requires some synchronization due to the more complicated dependency between all
the different inputs and gradients Q, K, V, O, dO, dQ, dK, dV. Nevertheless, avoiding “split-K” reduces shared
memory reads/writes and again yields speedup (Section 4).

Tuning block sizes. Increasing block sizes generally reduces shared memory loads/stores, but increases
the number of registers required and the total amount of shared memory. Past a certain block size, register
spilling causes significant slowdown, or the amount of shared memory required is larger than what the GPU
has available, and the kernel cannot run at all. Typically we choose blocks of size {64, 128} × {64, 128},
depending on the head dimension 𝑑 and the device shared memory size.
    We manually tune for each head dimension since there are essentially only 4 choices for block sizes, but
this could benefit from auto-tuning to avoid this manual labor. We leave this to future work.
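To make the shared memory side of this trade-off concrete, the sketch below enumerates the four candidate block
sizes and filters them by a rough footprint estimate. The footprint model is an assumption, not taken from the
kernel: it counts one Q tile plus one K tile and one V tile in 16-bit precision and ignores register pressure,
extra buffers, and layout padding. The function names and the 48 KiB budget in the usage line are hypothetical.

```python
from itertools import product

def tile_smem_bytes(block_m, block_n, head_dim, bytes_per_elem=2):
    """Assumed footprint: Q tile (block_m x d) plus K and V tiles (block_n x d each)."""
    return (block_m + 2 * block_n) * head_dim * bytes_per_elem

def candidate_blocks(head_dim, smem_limit_bytes):
    """Block sizes from {64, 128} x {64, 128} whose assumed footprint fits the budget."""
    return [
        (bm, bn)
        for bm, bn in product((64, 128), repeat=2)
        if tile_smem_bytes(bm, bn, head_dim) <= smem_limit_bytes
    ]

# Hypothetical budget of 48 KiB, chosen only for illustration.
print(candidate_blocks(head_dim=128, smem_limit_bytes=48 * 1024))
```

An auto-tuner along these lines would still need to time each surviving candidate, since this estimate says
nothing about register spilling.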


4     Empirical Validation
We evaluate the impact of using FlashAttention-2 to train Transformer models.
• Benchmarking attention. We measure the runtime of FlashAttention-2 across different sequence
  lengths and compare it to a standard implementation in PyTorch, FlashAttention, and FlashAttention
  in Triton. We confirm that FlashAttention-2 is 1.7-3.0× faster than FlashAttention, 1.3-2.5×
  faster than FlashAttention in Triton, and 3-10× faster than a standard attention implementation.
  FlashAttention-2 reaches up to 230 TFLOPs/s, 73% of the theoretical maximum TFLOPs/s on A100
  GPUs.
• End-to-end training speed. When used end-to-end to train GPT-style models of size 1.3B and 2.7B on
  sequence lengths of either 2k or 8k, FlashAttention-2 yields up to 1.3× speedup compared to
  FlashAttention and 2.8× speedup compared to a baseline without FlashAttention. FlashAttention-2
  reaches up to 225 TFLOPs/s (72% model FLOPs utilization) per A100 GPU.

4.1    Benchmarking Attention
We measure the runtime of different attention methods on an A100 80GB SXM4 GPU for different settings
(without / with causal mask, head dimension 64 or 128). We report the results in Fig. 4, Fig. 5 and Fig. 6,
showing that FlashAttention-2 is around 2× faster than FlashAttention and FlashAttention in
xformers (the “cutlass” implementation). FlashAttention-2 is around 1.3-1.5× faster than FlashAttention
in Triton in the forward pass and around 2× faster in the backward pass. Compared to a standard
attention implementation in PyTorch, FlashAttention-2 can be up to 10× faster.
    Benchmark setting: we vary the sequence length over 512, 1k, ..., 16k, and set the batch size so that the total
number of tokens is 16k. We set the hidden dimension to 2048, and the head dimension to either 64 or 128 (i.e.,
32 heads or 16 heads). To calculate the FLOPs of the forward pass, we use:

    $$4 \cdot \text{seqlen}^2 \cdot \text{head dimension} \cdot \text{number of heads}.$$

With causal mask, we divide this number by 2 to account for the fact that approximately only half of the
entries are calculated. To get the FLOPs of the backward pass, we multiply the forward pass FLOPs by 2.5
(since there are 2 matmuls in the forward pass and 5 matmuls in the backward pass, due to recomputation).
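The FLOP accounting above can be written out as a short helper. This is a minimal sketch under stated
assumptions: the function and parameter names are hypothetical, the batch multiplier is added here because the
formula in the text is written for a single sequence, and the combined forward + backward factor of 3.5 is simply
1 + 2.5.

```python
def attention_flops(seqlen, head_dim, n_heads, batch=1, causal=False, mode="fwd"):
    """FLOPs for attention following the convention in the text.

    mode is one of "fwd", "bwd", "fwd_bwd"; `batch` is an added assumption.
    """
    fwd = 4 * seqlen**2 * head_dim * n_heads * batch
    if causal:
        fwd = fwd / 2  # roughly half of the entries are computed
    factor = {"fwd": 1.0, "bwd": 2.5, "fwd_bwd": 3.5}[mode]
    return fwd * factor

def tflops_per_s(flops, seconds):
    """Convert a FLOP count and a measured runtime into TFLOPs/s."""
    return flops / seconds / 1e12

def benchmark_settings(total_tokens=16 * 1024, hidden_dim=2048, head_dim=64):
    """(batch, seqlen, n_heads) triples for sequence lengths 512, 1k, ..., 16k."""
    seqlens = [512 * 2**i for i in range(6)]
    return [(total_tokens // s, s, hidden_dim // head_dim) for s in seqlens]

for batch, seqlen, n_heads in benchmark_settings(head_dim=128):
    print(batch, seqlen, n_heads, attention_flops(seqlen, 128, n_heads, batch, mode="fwd_bwd"))
```

Dividing the resulting FLOP count by a measured runtime with tflops_per_s gives numbers in the same unit as the
figures.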

[Figure 4: bar charts of attention forward + backward speed (A100 80GB SXM4), in TFLOPs/s versus sequence length
(512 to 16k), for PyTorch, FlashAttention, xformers, FlashAttention Triton, and FlashAttention-2. Panels:
(a) without causal mask, head dimension 64; (b) without causal mask, head dimension 128; (c) with causal mask,
head dimension 64; (d) with causal mask, head dimension 128.]

Figure 4: Attention forward + backward speed on A100 GPU

[Figure 5: bar charts of attention forward speed (A100 80GB SXM4), in TFLOPs/s versus sequence length
(512 to 16k), for PyTorch, FlashAttention, xformers, FlashAttention Triton, and FlashAttention-2. Panels:
(a) without causal mask, head dimension 64; (b) without causal mask, head dimension 128; (c) with causal mask,
head dimension 64; (d) with causal mask, head dimension 128.]

Figure 5: Attention forward speed on A100 GPU


    Just running the same implementation on H100 GPUs (using no special instructions to make use of new
features such as TMA and 4th-gen Tensor Cores), we obtain up to 335 TFLOPs/s (Fig. 7). We expect that by
using new instructions, we can obtain another 1.5-2× speedup on H100 GPUs. We leave that to future work.

4.2    End-to-end Performance
We measure the training throughput of GPT-style models with either 1.3B or 2.7B parameters, on 8×A100
80GB SXM. As shown in Table 1, FlashAttention-2 yields 2.8× speedup compared to a baseline without
FlashAttention and 1.3× speedup compared to FlashAttention, reaching up to 225 TFLOPs/s per A100
GPU.
   Note that we calculate the FLOPs with the following formula, following Megatron-LM [16] (and many other
papers and libraries):

    $$6 \cdot \text{seqlen} \cdot \text{number of params} + 12 \cdot \text{number of layers} \cdot \text{hidden dim} \cdot \text{seqlen}^2.$$

The first term accounts for the FLOPs due to weight–input multiplication, and the second term accounts for
the FLOPs due to attention. However, one can argue that the second term should be halved, as with causal
mask we only need to compute approximately half the number of elements in attention. We choose to follow
the formula from the literature (without dividing the attention FLOPs by 2) for consistency.
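As a rough illustration of this accounting, the sketch below evaluates the formula and the resulting model FLOPs
utilization. Several things here are assumptions rather than statements from the text: the function names are
hypothetical, the formula is read as the cost of one training sequence (forward plus backward), the peak of
312 TFLOPs/s is the A100's dense FP16/BF16 Tensor Core figure (consistent with 225 TFLOPs/s being reported as
72%), and the 24-layer, 2048-hidden configuration for the 1.3B model is taken from the published GPT-3
configurations.

```python
A100_PEAK_TFLOPS = 312.0  # assumed peak: dense FP16/BF16 Tensor Core throughput

def training_flops_per_sequence(seqlen, n_params, n_layers, hidden_dim):
    """6 * seqlen * params + 12 * layers * hidden * seqlen^2 (no causal halving)."""
    return 6 * seqlen * n_params + 12 * n_layers * hidden_dim * seqlen**2

def attention_flops_share(seqlen, n_params, n_layers, hidden_dim):
    """Fraction of the total FLOPs contributed by the attention term."""
    attn = 12 * n_layers * hidden_dim * seqlen**2
    return attn / training_flops_per_sequence(seqlen, n_params, n_layers, hidden_dim)

def model_flops_utilization(achieved_tflops, peak_tflops=A100_PEAK_TFLOPS):
    """Achieved throughput as a fraction of the assumed hardware peak."""
    return achieved_tflops / peak_tflops

# Assumed configuration for a 1.3B-parameter GPT-style model (not given in the text).
for seqlen in (2048, 8192):
    print(seqlen, attention_flops_share(seqlen, 1.3e9, n_layers=24, hidden_dim=2048))

print(model_flops_utilization(225))
```

Under these assumptions the attention term takes a larger share of the total at 8k context than at 2k, which
fits the larger end-to-end gains reported for the 8k rows of Table 1.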


5                    Discussion and Future Directions
FlashAttention-2 is 2× faster than FlashAttention, which means that we can train models with 16k
longer context for the same price as previously training a 8k context model. We are excited about how this can


                                                                                                                      11
                                   Attention backward speed (A100 80GB SXM4)                                                                                 Attention backward speed (A100 80GB SXM4)
                                  Pytorch                                                                                                                  Pytorch
                                  FlashAttention                                                                                                           FlashAttention                                                               196
                    200           xformers                                                                                                    200          xformers                                        187           193
                                                                                                                                                                                             175
                                  FlashAttention Triton                                               169           170                                    FlashAttention 159
                                                                                                                                                                          Triton




 Speed (TFLOPs/s)                                                                                                          Speed (TFLOPs/s)
                                                                                      163
                                  FlashAttention-2                    152                                                                                  FlashAttention-2
                    150                               141                                                                                     150                136
                                      120                                                                   113
                                                             106             109             112
                    100      91              90 92                                                                                            100                                                                97 90
                                                                   87              86              87           88                                                         84        86 88         888489          86 82       889181
                                  81                                                                                                                    7876                           79 77             80
                                                 67              70              70              69           68                                            68         7375 74
                                62                                                                                                                    59
                                                            48              49              51
                    50     39               43                                                                                                50

                                                                                                        OOM                                                                                                                 OOM
                                512              1k              2k              4k              8k           16k                                          512            1k            2k            4k            8k            16k
                                                        Sequence length                                                                                                           Sequence length
                      (a) Without causal mask, head dimension 64                                                                               (b) Without causal mask, head dimension 128
                                   Attention backward speed (A100 80GB SXM4)                                                                                 Attention backward speed (A100 80GB SXM4)
                                  Pytorch                                                                                                                  Pytorch
                                  FlashAttention                                                                                                           FlashAttention
                    200           xformers                                                                                                    200          xformers                                                                     186
                                                                                                                                                                                                                         176
                                  FlashAttention Triton                                                                                                    FlashAttention Triton




 Speed (TFLOPs/s)                                                                                                          Speed (TFLOPs/s)
                                                                                                      160           166                                                                                    165
                                  FlashAttention-2                                    149                                                                  FlashAttention-2                  145
                    150                                                                                                                       150
                                                                      131
                                                                                                                                                                               122
                                                      111
                                                                                                            98
                    100                                                      85
[Figure 6, panels (c) and (d): attention backward speed plotted against sequence length (512, 1k, 2k, 4k, 8k, 16k); an OOM (out-of-memory) marker appears at sequence length 16k in each panel.]
                          (c) With causal mask, head dimension 64
                          (d) With causal mask, head dimension 128

                                                                        Figure 6: Attention backward speed on A100 GPU

Table 1: Training speed (TFLOPs/s/GPU) of GPT-style models on 8×A100 GPUs. FlashAttention-2
reaches up to 225 TFLOPs/s (72% model FLOPs utilization). We compare against a baseline running without
FlashAttention.

                    Model                    Without FlashAttention    FlashAttention    FlashAttention-2
                    GPT3-1.3B 2k context     142 TFLOPs/s              189 TFLOPs/s      196 TFLOPs/s
                    GPT3-1.3B 8k context      72 TFLOPs/s              170 TFLOPs/s      220 TFLOPs/s
                    GPT3-2.7B 2k context     149 TFLOPs/s              189 TFLOPs/s      205 TFLOPs/s
                    GPT3-2.7B 8k context      80 TFLOPs/s              175 TFLOPs/s      225 TFLOPs/s
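
The utilization figure in the Table 1 caption can be checked with a few lines of arithmetic. The sketch below is illustrative only: it assumes a peak of 312 TFLOPs/s per A100 for dense FP16/BF16 matmul, and the names `A100_PEAK_TFLOPS`, `TABLE_1`, `utilization`, and `summarize` are hypothetical helpers, not part of any FlashAttention code. The throughput values are copied from Table 1.

```python
# Assumption: peak dense FP16/BF16 matmul throughput of one A100, in TFLOPs/s.
A100_PEAK_TFLOPS = 312.0

# Training speed in TFLOPs/s/GPU, copied from Table 1, in the order
# (without FlashAttention, FlashAttention, FlashAttention-2).
TABLE_1 = {
    "GPT3-1.3B 2k context": (142, 189, 196),
    "GPT3-1.3B 8k context": (72, 170, 220),
    "GPT3-2.7B 2k context": (149, 189, 205),
    "GPT3-2.7B 8k context": (80, 175, 225),
}


def utilization(tflops, peak=A100_PEAK_TFLOPS):
    """Model FLOPs utilization: achieved throughput divided by peak throughput."""
    return tflops / peak


def summarize(table):
    """Per model: FlashAttention-2 utilization and its speedup over the other two columns."""
    rows = {}
    for model, (baseline, fa1, fa2) in table.items():
        rows[model] = {
            "mfu_fa2": utilization(fa2),
            "speedup_vs_baseline": fa2 / baseline,
            "speedup_vs_fa1": fa2 / fa1,
        }
    return rows


if __name__ == "__main__":
    for model, row in summarize(TABLE_1).items():
        print(
            f"{model}: MFU {row['mfu_fa2']:.0%}, "
            f"{row['speedup_vs_baseline']:.2f}x vs. no FlashAttention, "
            f"{row['speedup_vs_fa1']:.2f}x vs. FlashAttention"
        )
```

Under this assumed peak, the best row (GPT3-2.7B at 8k context, 225 TFLOPs/s) works out to roughly 72% utilization, matching the caption.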


be used to understand long books and reports, high resolution images, audio and video. FlashAttention-2
will also speed up training, finetuning, and inference of existing models.
    In the near future, we plan to collaborate with researchers and engineers to make FlashAttention widely
applicable on different kinds of devices (e.g., H100 GPUs, AMD GPUs), as well as to new data types such as
FP8. As an immediate next step, we plan to optimize FlashAttention-2 for H100 GPUs to use new hardware
features (TMA, 4th-gen Tensor Cores, FP8). Combining the low-level optimizations in FlashAttention-2 with
high-level algorithmic changes (e.g., local, dilated, block-sparse attention) could allow us to train AI models
with much longer context. We are also excited to work with compiler researchers to make these optimization
techniques easily programmable.




[Plot panels: attention forward + backward speed in TFLOPs/s (H100 80GB SXM5) against sequence length (512, 1k, 2k, 4k, 8k, 16k) for PyTorch, FlashAttention, and FlashAttention-2; an OOM (out-of-memory) marker appears at sequence length 16k in each panel.]
                      (a) Without causal mask, head dimension 64
                      (b) Without causal mask, head dimension 128
                      (c) With causal mask, head dimension 64
                      (d) With causal mask, head dimension 128

                                                   Figure 7: Attention forward + backward speed on H100 GPU


Acknowledgments
We thank Phil Tillet and Daniel Haziza, who have implemented versions of FlashAttention in Triton [17] and
the xformers library [10]. FlashAttention-2 was motivated by an exchange of ideas about the different ways
attention could be implemented. We are grateful to the Nvidia CUTLASS team (especially Vijay Thakkar, Cris
Cecka, Haicheng Wu, and Andrew Kerr) for their CUTLASS library, in particular the CUTLASS 3.x release,
which provides clean abstractions and powerful building blocks for the implementation of FlashAttention-2.
We thank Driss Guessous for integrating FlashAttention into PyTorch. FlashAttention-2 has benefited
from helpful discussions with Phil Wang, Markus Rabe, James Bradbury, Young-Jun Ko, Julien Launay,
Daniel Hesslow, Michaël Benesty, Horace He, Ashish Vaswani, and Erich Elsen. Thanks to Stanford CRFM
and Stanford NLP for the compute support. We thank Dan Fu and Christopher Ré for their collaboration,
constructive feedback, and constant encouragement on this line of work of designing hardware-efficient
algorithms. We thank Albert Gu and Beidi Chen for their helpful suggestions on early drafts of this technical
report.


References
 [1] Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit
     Sanghai. GQA: Training generalized multi-query transformer models from multi-head checkpoints. arXiv
     preprint arXiv:2305.13245, 2023.
 [2] Iz Beltagy, Matthew E Peters, and Arman Cohan. Longformer: The long-document transformer. arXiv
     preprint arXiv:2004.05150, 2020.
 [3] Beidi Chen, Tri Dao, Eric Winsor, Zhao Song, Atri Rudra, and Christopher Ré. Scatterbrain: Unifying
     sparse and low-rank attention. In Advances in Neural Information Processing Systems (NeurIPS), 2021.
 [4] Krzysztof Marcin Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane,
     Tamas Sarlos, Peter Hawkins, Jared Quincy Davis, Afroz Mohiuddin, Lukasz Kaiser, et al. Rethinking
     attention with performers. In International Conference on Learning Representations (ICLR), 2020.
 [5] Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. FlashAttention: Fast and
     memory-efficient exact attention with IO-awareness. In Advances in Neural Information Processing
     Systems, 2022.
 [6] Zhe Jia and Peter Van Sandt. Dissecting the Ampere GPU architecture via microbenchmarking. GPU
     Technology Conference, 2021.
 [7] Zhe Jia, Marco Maggioni, Benjamin Staiger, and Daniele P Scarpazza. Dissecting the NVIDIA Volta GPU
     architecture via microbenchmarking. arXiv preprint arXiv:1804.06826, 2018.
 [8] Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are RNNs:
     Fast autoregressive transformers with linear attention. In International Conference on Machine Learning,
     pages 5156–5165. PMLR, 2020.
 [9] Nikita Kitaev, Łukasz Kaiser, and Anselm Levskaya. Reformer: The efficient transformer. In
     International Conference on Learning Representations (ICLR), 2020.
[10] Benjamin Lefaudeux, Francisco Massa, Diana Liskovich, Wenhan Xiong, Vittorio Caggiano, Sean Naren,
     Min Xu, Jieru Hu, Marta Tintore, Susan Zhang, Patrick Labatut, and Daniel Haziza. xformers: A modular
     and hackable transformer modelling library. https://github.com/facebookresearch/xformers, 2022.
[11] Maxim Milakov and Natalia Gimelshein. Online normalizer calculation for softmax. arXiv preprint
     arXiv:1805.02867, 2018.
[12] OpenAI. GPT-4 technical report. arXiv preprint arXiv:2303.08774, 2023.
[13] Markus N Rabe and Charles Staats. Self-attention does not need O(n^2) memory. arXiv preprint
     arXiv:2112.05682, 2021.
[14] Aurko Roy, Mohammad Saffar, Ashish Vaswani, and David Grangier. Efficient content-based sparse
     attention with routing transformers. Transactions of the Association for Computational Linguistics, 9:
     53–68, 2021.
[15] Noam Shazeer. Fast transformer decoding: One write-head is all you need. arXiv preprint
     arXiv:1911.02150, 2019.
[16] Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper, and Bryan Catanzaro.
     Megatron-LM: Training multi-billion parameter language models using model parallelism. arXiv preprint
     arXiv:1909.08053, 2019.
[17] Philippe Tillet, Hsiang-Tsung Kung, and David Cox. Triton: an intermediate language and compiler for
     tiled neural network computations. In Proceedings of the 3rd ACM SIGPLAN International Workshop
     on Machine Learning and Programming Languages, pages 10–19, 2019.
[18] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz
     Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing
     systems, 30, 2017.
[19] Sinong Wang, Belinda Z Li, Madian Khabsa, Han Fang, and Hao Ma. Linformer: Self-attention with
     linear complexity. arXiv preprint arXiv:2006.04768, 2020.
[20] Manzil Zaheer, Guru Guruganesh, Kumar Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago
     Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, et al. Big bird: Transformers for longer
     sequences. Advances in Neural Information Processing Systems, 33, 2020.

