Flash Attention (2022)
Optimised, IO-aware kernel for computing attention on a GPU; the original 2022 paper benchmarks it primarily on the NVIDIA A100.
Standard attention implementations write the full n×n attention matrix to GPU high-bandwidth memory (HBM), read it back for softmax, write again, then read for the final multiplication.
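A minimal PyTorch sketch of that standard pattern (shapes and names are illustrative, not taken from any particular library):

```python
import torch

def naive_attention(q, k, v):
    """q, k, v: (batch, heads, n, d). Materialises the full n x n score matrix."""
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d ** 0.5   # full n x n matrix written to HBM
    probs = torch.softmax(scores, dim=-1)         # read back, softmaxed, written again
    return probs @ v                              # read once more for the final matmul
```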
Flash Attention restructures the computation around the GPU memory hierarchy: queries, keys, and values are processed in tiles small enough to fit in on-chip SRAM, and the softmax is computed incrementally (online) across tiles, so the full n×n matrix is never materialised in HBM.
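A minimal sketch of the tiling and online-softmax idea, assuming a single head and plain PyTorch tensors; the real kernel also tiles the query dimension and keeps the working blocks in SRAM, but the arithmetic is the same:

```python
import torch

def tiled_attention(q, k, v, block_size=128):
    """q, k, v: (n, d), single head. Never forms the full n x n score matrix."""
    n, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    # Running softmax statistics, one per query row.
    row_max = torch.full((n, 1), float("-inf"))
    row_sum = torch.zeros(n, 1)

    for start in range(0, n, block_size):
        kb = k[start:start + block_size]           # one key block
        vb = v[start:start + block_size]           # one value block
        scores = (q @ kb.T) * scale                # (n, block) partial scores only

        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)  # rescale earlier accumulators
        p = torch.exp(scores - new_max)            # block-local exponentials

        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ vb
        row_max = new_max

    return out / row_sum
```

Because the running max and running sum rescale the earlier partial results exactly, the output matches standard attention to numerical precision, which is why the method is exact rather than an approximation.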
The result is exact attention (no approximation) with memory linear in sequence length, O(n) instead of O(n²). Wall-clock speed also improved by 2–4× for long sequences.
Enabled training with 16K+ token sequences on hardware that previously supported only 2K. The maximum sequence length was no longer limited by the size of the attention matrix (which caused OOM errors) but only by the linear cost of storing the KV cache.
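A back-of-the-envelope comparison of the two costs at 16K tokens, per head and per batch element (head dimension 128 and fp16 storage are assumptions for illustration):

```python
# Rough memory arithmetic at sequence length n = 16384, head dim d = 128, fp16.
n, d, bytes_per_elem = 16_384, 128, 2

attn_matrix = n * n * bytes_per_elem      # full n x n score matrix
kv_cache = 2 * n * d * bytes_per_elem     # keys + values

print(f"n x n attention matrix: {attn_matrix / 2**20:.0f} MiB")  # ~512 MiB
print(f"KV cache (K and V):     {kv_cache / 2**20:.0f} MiB")     # ~8 MiB
```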
Parallel over batch dimension: the original kernel parallelises work across the batch and head dimensions (one thread block per batch element and head), reaching ~150 TFLOPS (FP16) on the A100.
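A sketch of how that parallelism looks at the tensor level, reusing the tiled_attention sketch above; the Python loop stands in for the independent thread blocks the kernel launches, one per (batch, head) pair:

```python
import torch

def batched_tiled_attention(q, k, v, block_size=128):
    # q, k, v: (batch, heads, n, d). Every (batch, head) slice is independent,
    # so each one can be assigned to its own thread block on the GPU.
    b, h, n, d = q.shape
    q2, k2, v2 = (t.reshape(b * h, n, d) for t in (q, k, v))
    out = torch.stack([
        tiled_attention(q2[i], k2[i], v2[i], block_size)  # sketch defined above
        for i in range(b * h)
    ])
    return out.reshape(b, h, n, d)
```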
Flash Attention 2
Flash Attention 3
Tags: AI