Flash Attention 2 Explained
Flash Attention does not change the attention math. It changes how the math executes on a GPU. The result is a 2-4x speedup, a far smaller memory footprint, and the long contexts modern LLMs ship with.
The memory-bandwidth problem
Standard attention computes Q·K^T as an N×N score matrix, applies softmax row-wise, then multiplies by V. For sequence length 4096, that’s a roughly 16-million-element intermediate; for sequence length 32768, it’s over a billion elements.
This intermediate is far too big for GPU on-chip memory (SRAM, fast), so it lives in HBM (GPU main memory, much slower). Every read and write of it is bandwidth-limited: the arithmetic itself is fast; the memory traffic is the bottleneck.
On modern GPUs, naive attention can spend the large majority of its runtime waiting on memory rather than computing.
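As a concrete baseline, here is a minimal NumPy sketch of the standard computation (illustrative names and shapes, single head, no batching):

```python
import numpy as np

def naive_attention(Q, K, V):
    """Standard attention: materialises the full N x N score matrix."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                 # the N x N intermediate
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
N, d = 4096, 64
Q, K, V = (rng.standard_normal((N, d), dtype=np.float32) for _ in range(3))
out = naive_attention(Q, K, V)
# `scores` alone is N*N = ~16.8M floats (~67 MB in fp32) that a GPU kernel
# would stream through HBM, while Q, K, V together are only ~3 MB.
```

Bumping N to 32768 turns `scores` into a ~4 GB array. That quadratic intermediate, not the matmul arithmetic, is what Flash Attention eliminates.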
The tiling trick
Flash Attention reorganises the computation. Instead of materialising the whole N×N attention matrix, it processes the matrix in tiles small enough to fit in SRAM. Each tile is loaded, processed (softmax + matrix multiply), and reduced into the output, all without writing the intermediate to HBM.
The mathematical trick: a stable online softmax that can be incrementally updated as new tiles arrive. Each tile contributes to the output without ever needing the full normalisation constant up front.
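A didactic sketch of that online softmax for a single query row, in NumPy (the names are mine; the real kernel processes whole query tiles in registers, not Python loops):

```python
import numpy as np

def flash_attention_row(q, K, V, tile=128):
    """Online-softmax attention for one query vector q against tiled K, V.

    Keeps a running max m, running normaliser l, and unnormalised output acc,
    rescaling the old state whenever a new tile raises the max. The full
    score row is never materialised, let alone the N x N matrix.
    """
    d = q.shape[-1]
    m = -np.inf          # running max of all scores seen so far
    l = 0.0              # running sum of exp(score - m)
    acc = np.zeros(d)    # running unnormalised weighted sum of V rows
    for start in range(0, K.shape[0], tile):
        Kt, Vt = K[start:start + tile], V[start:start + tile]
        s = Kt @ q / np.sqrt(d)        # scores for this tile only
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)      # rescale old state to the new max
        p = np.exp(s - m_new)          # this tile's softmax numerators
        l = l * scale + p.sum()
        acc = acc * scale + p @ Vt
        m = m_new
    return acc / l                     # normalise once, at the very end

# Matches the naive computation:
rng = np.random.default_rng(0)
N, d = 512, 64
q = rng.standard_normal(d)
K = rng.standard_normal((N, d))
V = rng.standard_normal((N, d))
ref_s = K @ q / np.sqrt(d)
ref = np.exp(ref_s - ref_s.max()) / np.exp(ref_s - ref_s.max()).sum() @ V
assert np.allclose(flash_attention_row(q, K, V), ref)
```

The key step is the `scale` factor: because softmax is invariant to subtracting a constant, the partial sums can be re-based onto each new running max without revisiting earlier tiles.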
Result: same output, dramatically less HBM traffic. Speed up 2-4x. Memory use linear in N instead of quadratic.
Flash Attention 2 (2023)
Flash Attention 2 refined the implementation:
- Better work partitioning across warps and thread blocks, tuned for Ampere (A100) and Hopper (H100) GPUs.
- Reordered loops, parallelising over the sequence-length dimension as well as batch and heads, for higher occupancy.
- Native causal masking at essentially no extra cost (fully-masked tiles are skipped outright).
- Variable-length sequences in the same batch (essential for inference batching).
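To see why causal masking is nearly free in a tiled kernel, note that the mask can be decided per tile before any scores are computed. A sketch of that tile-level decision (my own illustrative helper, not FA2's actual code):

```python
def tile_status(q_tile, k_tile, tile=128):
    """Classify a (query-tile, key-tile) pair under a causal mask.

    Query positions in q_tile span [q_tile*tile, (q_tile+1)*tile);
    likewise for keys. A query may only attend to keys at or before
    its own position.
    """
    q_lo, q_hi = q_tile * tile, (q_tile + 1) * tile - 1
    k_lo, k_hi = k_tile * tile, (k_tile + 1) * tile - 1
    if k_lo > q_hi:
        return "skip"      # every key is in the future: no work at all
    if k_hi <= q_lo:
        return "full"      # every key is visible: no masking needed
    return "masked"        # diagonal tile: apply the mask element-wise

# Over the whole tile grid, only diagonal tiles pay for element-wise
# masking; everything strictly above the diagonal is never computed.
statuses = [[tile_status(i, j) for j in range(4)] for i in range(4)]
```

Roughly half the tile grid is skipped, which is why the causal case runs at about the same cost per computed tile as the unmasked case.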
FA2 is 1.5-2x faster than FA1. It’s the default in modern training stacks and most inference servers (vLLM, TGI).
Why it changed the field
Long context windows (Claude’s 1M, Gemini’s 2M tokens) are not architecturally novel. They’re Flash Attention-enabled. Without FA, attention at 1M tokens would need terabytes of HBM for the score matrices and minutes per request. With FA, memory is linear in sequence length and requests complete in seconds.
Virtually every paper since 2023 on “long-context reasoning” or million-token models depends on Flash Attention or a derivative as a critical piece of infrastructure.
Limits
FA2 is still O(N^2) in compute; it just has far better memory locality. For truly long contexts (multiple millions of tokens), even FA2 strains, and newer techniques take over:
- Ring Attention: distributes attention across many GPUs for very long sequences.
- Sparse attention: skip computation for token pairs that almost certainly don’t matter.
- Hierarchical attention: combine fine-grained recent attention with coarse-grained older context.
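As one illustration of the sparse-attention idea, a sliding-window variant restricts each token to the most recent keys. A NumPy sketch (a hypothetical helper of my own, not any library's API):

```python
import numpy as np

def sliding_window_attention(Q, K, V, window=256):
    """Sparse attention sketch: each query attends only to the `window`
    most recent keys (itself included), so compute and memory scale as
    O(N * window) instead of O(N^2). Didactic, not a production kernel.
    """
    N, d = Q.shape
    out = np.empty_like(V)
    for i in range(N):
        lo = max(0, i - window + 1)
        s = K[lo:i + 1] @ Q[i] / np.sqrt(d)   # at most `window` scores
        p = np.exp(s - s.max())
        out[i] = p @ V[lo:i + 1] / p.sum()
    return out

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((1024, 64)) for _ in range(3))
out = sliding_window_attention(Q, K, V, window=128)  # 1024*128 scores, not 1024^2
```

The trade-off is the one named above: computation is skipped for token pairs outside the window on the bet that distant pairs almost certainly don't matter, which is an approximation, unlike Flash Attention itself.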
Flash Attention is the floor of the modern transformer stack. The frontier is what’s being built on top of it.