Flash Attention

Flash Attention (2022)

  • An optimised kernel for computing attention on a GPU, originally targeting the NVIDIA A100.

  • Standard attention implementations write the full $N \times N$ attention matrix to GPU high-bandwidth memory (HBM), read it back for the softmax, write it again, then read it again for the final multiplication (see the sketch after this list):

    1. Load $Q$ and $K$ from HBM to on-chip SRAM.
    2. Compute $S = QK^T$.
    3. Write $S$ (size $N^2$) back to HBM.
    4. Load $S$ from HBM to SRAM.
    5. Compute $P = \text{softmax}(S)$.
    6. Write $P$ back to HBM.
    7. Load $P$ and $V$ from HBM.
    8. Compute $O = PV$.
    9. Write $O$ to HBM.
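
A minimal NumPy sketch (not a GPU kernel) of this standard path, with comments mapping back to the nine steps above; the $1/\sqrt{d}$ scaling, which the steps omit, is included for completeness.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Standard attention: materialises the full N x N matrices S and P.

    Q, K, V have shape (N, d). On a GPU, S and P would each be written to
    and re-read from HBM, costing O(N^2) memory traffic.
    """
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                       # steps 1-3: S = QK^T, an N x N matrix
    P = np.exp(S - S.max(axis=-1, keepdims=True))  # steps 4-6: P = softmax(S), also N x N
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V                                   # steps 7-9: O = PV, size N x d

# Illustrative sizes: 1024 tokens, head dimension 64.
rng = np.random.default_rng(0)
N, d = 1024, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
print(naive_attention(Q, K, V).shape)  # (1024, 64)
```
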
  • Flash Attention restructures computation around the GPU memory hierarchy.

    1. Load small blocks of $Q$, $K$, and $V$ into SRAM. This technique is called tiling.
    2. Compute the attention scores, the (online) softmax, and the output incrementally within the fast SRAM (see the sketch after this list).
    3. Ideally, never write the intermediate $N \times N$ matrix to HBM at all.
    • On an A100, on-chip SRAM is ~20 MB with ~19 TB/s of bandwidth, versus ~1.5 TB/s for HBM.
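
A CPU-side NumPy sketch of tiling plus the online softmax, under assumed block sizes of 128; nothing larger than a block_q × block_k score tile is ever materialised, and the result matches a naive reference up to floating-point error. This illustrates the algorithmic idea only, not the actual CUDA kernel.

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_q=128, block_k=128):
    """Tiled attention with an online softmax (the tiles stand in for SRAM)."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))

    for i in range(0, N, block_q):                    # one query block at a time
        Qi = Q[i:i + block_q]
        row_max = np.full(Qi.shape[0], -np.inf)       # running row-wise max
        denom = np.zeros(Qi.shape[0])                 # running softmax denominator
        acc = np.zeros((Qi.shape[0], d))              # running unnormalised output

        for j in range(0, N, block_k):                # stream over key/value blocks
            Kj, Vj = K[j:j + block_k], V[j:j + block_k]
            S = Qi @ Kj.T * scale                     # small score tile ("in SRAM")

            new_max = np.maximum(row_max, S.max(axis=1))
            alpha = np.exp(row_max - new_max)         # rescale old statistics to the new max
            P = np.exp(S - new_max[:, None])
            denom = denom * alpha + P.sum(axis=1)
            acc = acc * alpha[:, None] + P @ Vj
            row_max = new_max

        O[i:i + block_q] = acc / denom[:, None]       # single write of the output block
    return O

# Agrees with a directly computed softmax(QK^T / sqrt(d)) V.
rng = np.random.default_rng(0)
N, d = 512, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
print(np.allclose(flash_attention_forward(Q, K, V), ref))  # True
```
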
  • The result is exact attention (no approximation) with $O(N)$ memory complexity. Wall-clock speed also improved by 2–4× for long sequences.

  • Enabled training with 16K+ token sequences on hardware that previously supported only 2K. The maximum sequence length was no longer limited by the size of the attention matrix (which caused OOM errors) but only by the linear cost of storing the KV cache.
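
Back-of-the-envelope arithmetic for a single attention head at 16K tokens in FP16, assuming an illustrative head dimension of 128, showing the quadratic-versus-linear gap:

```python
# Rough per-head memory at 16K tokens in FP16 (the head dimension is an assumption).
N, d = 16_384, 128
bytes_fp16 = 2

attn_matrix = N * N * bytes_fp16        # the N x N score matrix
kv_cache = 2 * N * d * bytes_fp16       # one K vector and one V vector per token

print(f"N x N attention matrix: {attn_matrix / 2**20:.0f} MiB")  # ~512 MiB
print(f"KV cache:               {kv_cache / 2**20:.0f} MiB")     # ~8 MiB
```
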

  • Parallelises only over the batch dimension and the number of attention heads, so long sequences with small batches leave many streaming multiprocessors idle.

  • ~150 TFLOPS (FP16)

  • Flash Attention 2

    • 2023
    • Added parallelisation over the sequence dimension, allowing it to use more streaming multiprocessors on the GPU for long sequences (see the sketch after this list).
    • Reduced the number of non-matrix-multiplication operations (which cannot use the Tensor Cores), e.g. softmax rescaling and masking.
    • ~350 TFLOPS (FP16)
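
A rough illustration of why sequence-dimension parallelism helps: each query block's output depends only on its own rows of $Q$ (plus all of $K$ and $V$), so the blocks can be handed to independent workers, which the GPU maps to separate thread blocks on different SMs. The Python threads and the helper below are an analogy, not the kernel's actual scheduling, and the per-block score rows are materialised here only to keep the sketch short.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def attention_for_query_block(Qi, K, V):
    """Attention output for one block of query rows (independent of other blocks)."""
    S = Qi @ K.T / np.sqrt(Qi.shape[-1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
N, d, block_q = 2048, 64, 256
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
query_blocks = [Q[i:i + block_q] for i in range(0, N, block_q)]

# Threads stand in for thread blocks scheduled across streaming multiprocessors.
with ThreadPoolExecutor() as pool:
    O = np.vstack(list(pool.map(lambda Qi: attention_for_query_block(Qi, K, V),
                                query_blocks)))
print(O.shape)  # (2048, 64)
```
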
  • Flash Attention 3

    • 2024
    • Optimised for NVIDIA H100 Hopper GPUs.
    • The H100 introduced the Tensor Memory Accelerator (TMA), which allows the GPU to copy memory from HBM to SRAM asynchronously, without blocking the compute threads.
    • Warp specialisation:
      • Some warps act as producers (loading data via TMA) while others act as consumers (computing the GEMMs), forming a pipeline in which data transfer and computation overlap (see the sketch after this list).
    • FP8 support
    • ~740 TFLOPS (FP16) / 1.2 PFLOPS (FP8)
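
A loose CPU analogy for warp specialisation, assuming a bounded queue as the shared buffer: one thread prefetches tiles (the producer / TMA role) while another runs matrix multiplies on whatever tile is ready (the consumer role). The tiled GEMM is illustrative; it is not the FlashAttention-3 kernel itself.

```python
import queue
import threading
import numpy as np

def producer(A, B, block, buf):
    """Stand-in for producer warps: prefetch the next pair of tiles
    (the role TMA plays for HBM -> SRAM copies) into a bounded buffer."""
    for j in range(0, A.shape[1], block):
        buf.put((A[:, j:j + block], B[j:j + block]))
    buf.put(None)                        # end-of-stream marker

def consumer(buf, out_shape):
    """Stand-in for consumer warps: multiply whatever tile pair is ready
    while the producer fetches the next one."""
    C = np.zeros(out_shape)
    while (tiles := buf.get()) is not None:
        Aj, Bj = tiles
        C += Aj @ Bj                     # GEMM on the current tile
    return C

rng = np.random.default_rng(0)
A, B = rng.standard_normal((256, 1024)), rng.standard_normal((1024, 64))
buf = queue.Queue(maxsize=2)             # fetching tile j+1 overlaps compute on tile j
threading.Thread(target=producer, args=(A, B, 128, buf), daemon=True).start()
C = consumer(buf, (256, 64))
print(np.allclose(C, A @ B))             # True
```
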

Tags: AI