How transformer-based language models got longer context windows

In 2019, models such as BERT and GPT-2 were limited to relatively short inputs (e.g., 512 or 1024 tokens) due to the quadratic time/memory costs of attention. Nowadays, commercial language models have maximum input lengths of hundreds of thousands of tokens, with some reaching millions. How did this happen?

Understanding attention and its quadratic cost

Self-attention

In self-attention, the relationship between every token and every other token must be computed. The $Q \cdot K^T$ operation that produces the $n \times n$ attention matrix requires $n^2 \cdot d$ operations, and the $n^2$ results have to be stored and passed through the softmax. For a 512-token sequence, that means 262,144 values. For 100,000 tokens, 10,000,000,000 values would be needed.
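
A minimal NumPy sketch of the naive computation (an illustration, not any specific library's implementation) makes the $n \times n$ cost concrete:

```python
import numpy as np

def naive_attention(Q, K, V):
    """Full self-attention: materialises the entire n x n score matrix."""
    n, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)                    # (n, n): ~n^2 * d multiply-adds
    scores -= scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax over n^2 entries
    return weights @ V                               # (n, d)

# n = 512     -> 262,144 scores to store
# n = 100,000 -> 10,000,000,000 scores to store
```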

This set a ceiling on the length of inputs we could practically feed to transformers. Models like BERT, GPT-2, RoBERTa, and BART limited themselves to inputs of no more than about 1,024 tokens.

Attempts to make more efficient transformers

In 2020, a series of papers (e.g., Longformer, BigBird, Reformer, Linformer) aimed to increase the maximum supported input sequence length by sacrificing the all-to-all relationships of self-attention in a controlled way, either by limiting each token’s field of view (local/sparse attention) or by compressing the attention computation (low-rank or hashed approximations).
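
To make the "limited field of view" idea concrete, here is a generic sliding-window (local) attention mask; this is a sketch of the general pattern, not any specific paper's implementation:

```python
import numpy as np

def sliding_window_mask(n, window):
    """True where attention is allowed: token i attends causally to tokens j
    with i - window < j <= i, cutting cost from n^2 to roughly n * window."""
    idx = np.arange(n)
    return (idx[None, :] <= idx[:, None]) & (idx[:, None] - idx[None, :] < window)

mask = sliding_window_mask(8, window=3)  # toy size; real models used windows of 256-4096
```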

Limitations:

  • Some tasks that truly require global interactions could still suffer when using only sparse/local attention.
  • Required custom CUDA kernels
  • Struggled with autoregressive generation
  • Quality degradation on tasks requiring precise long-range dependencies
  • Unsuitability for GPU hardware
    • Irregular memory access patterns
      • GPUs are Single Instruction, Multiple Data (SIMD) machines. They excel at performing the same operation on contiguous blocks of memory (dense matrix multiplication).
      • Sparse attention requires gathering non-contiguous data points from memory based on indices (random attention or specific sliding windows).
      • Sparse operations lead to uncoalesced memory access. The GPU might fetch a 128-byte cache line to use only 16 bytes of data, wasting memory bandwidth.
      • Implementing block-sparse attention required writing custom CUDA kernels. These kernels were often difficult to optimise and could not leverage the heavily optimised Tensor Core primitives designed for dense GEMM (General Matrix Multiply) operations.
      • https://arxiv.org/pdf/2306.01160
    • Wall clock reality
      • While sparse attention reduced the theoretical FLOPs, it often did not reduce the wall-clock training time.
      • The overhead of managing the sparsity patterns (masking, gathering indices) often outweighed the savings in compute.
      • A dense operation that uses 100% of the GPU's potential is often faster than a sparse operation that uses only 10%.
    • "Needle" Problem
      • Sparse attention assumes that not all token interactions are important
      • But in needle-in-a-haystack tasks, a single specific token interaction might be the only one that matters.
      • If the random or sliding window sparsity pattern fails to capture that specific connection, the model will fail.
    • The path forward was not to do less work (sparsity), but to do it more efficiently (IO-awareness).

Scaling long sequences with recurrence and memory

An alternative approach: break long sequences into segments and propagate information across segments, effectively giving an “infinite” context length through recurrence.

Transformer-XL (2019) introduced segment-level recurrence. Instead of discarding the hidden states of the previous segment after processing, the model caches them in memory. It is then able to attend to the next segment along with the cached tokens from the one before. The Query is derived from the current segment, while the Key and Value vectors are derived from the concatenation of the previous segment's output (with gradients stopped) and the current segment.

Gradients are not back-propagated to the cached segment. But the forward pass still benefits from the history. This creates a sliding window of context that moves forward with every step.
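
A single-layer NumPy sketch of segment-level recurrence, under simplifying assumptions (one head, no positional terms, causal masking within the segment omitted, weights as plain matrices):

```python
import numpy as np

def xl_segment_step(x_curr, mem, W_q, W_k, W_v):
    """Process one segment with Transformer-XL-style recurrence.
    x_curr: (T, d) current segment; mem: (M, d) cached states of the previous segment."""
    context = np.concatenate([mem, x_curr], axis=0)  # keys/values see memory + current segment
    q = x_curr @ W_q                                 # queries come from the current segment only
    k = context @ W_k
    v = context @ W_v
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ v
    new_mem = x_curr.copy()  # in a real model the cache is detached: no gradients flow into it
    return out, new_mem
```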

Transformer-XL was also built on relative positional encodings. With absolute positional encodings, the cached tokens from segment A and the tokens of the current segment B would be assigned the same position indices and clash. In Transformer-XL, it is the relative distance between the Query and Key tokens that is used.

Compressive Transformer (2019) extended the ideas in the Transformer-XL paper. The authors distinguish between episodic memory, focused on detailed recollection of recent events, and semantic memory for storing compressed and more abstract concepts. The former works the same way as Transformer-XL's memory. The latter stores compressed states when the primary memory fills up. The compression is performed using a 1D convolution that was empirically shown to work better than, e.g. max or mean pooling.

During the forward pass, the Query vector attends to a concatenation of the current segment, the primary memory and the compressed memory. Because the compression function is differentiable, the model learns during training how to compress memories in a way that preserves the underlying information.
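
A sketch of the compression step under stated assumptions: a single strided 1D convolution with compression rate 3 and no bias; the shapes and names here are illustrative, not taken from the paper's code.

```python
import numpy as np

def compress_memories(old_mem, kernel, rate=3):
    """Compress evicted memories (T, d) into (T // rate, d) with a strided 1D conv.
    kernel: (rate, d, d) learned filter; because the op is differentiable, the
    model can learn which information to preserve."""
    T, d = old_mem.shape
    compressed = []
    for start in range(0, T - rate + 1, rate):
        window = old_mem[start:start + rate]                     # (rate, d)
        compressed.append(np.einsum("rd,rde->e", window, kernel))
    return np.stack(compressed)
```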

Understanding the challenges of positional embeddings

A fixed-size table of learned positional embeddings means that a model trained on short sequences might not generalise to longer ones.

Positional Encoding to support extended context windows

Early transformer models like BERT and GPT used learned absolute positional embeddings $X_{pos} = X_{token} + P_{pos}$, where $P_{pos}$ is a learned vector unique to each position index (0, 1, 2... 511).

They had no representation for positions beyond their training length. Their embeddings assigned each position index a learned vector. If training stopped at 512 positions (indices 0–511), position 512 was undefined or random. The model had no concept of "distance," only of unique indices. This prevented any generalisation to lengths longer than the training set.

Sinusoidal encodings theoretically could extrapolate, but in practice, perplexity degraded sharply on sequences even slightly beyond the training length.

  • ALiBi (Attention with Linear Bias)

    • 2021
    • https://arxiv.org/pdf/2108.12409
    • a simple method to encode positions that allows a model trained on shorter sequences to generalise to longer ones.
    • Instead of using explicit positional embeddings added to tokens, ALiBi introduces a static, non-learned penalty to the attention scores that grows linearly with the distance between query and key: $\text{Score}_{ij} = Q_i K_j^T - m \cdot |i - j|$, where $m$ is a slope specific to each attention head (both ALiBi and RoPE are sketched in code after this list).
    • penalises attention based on how far apart tokens are, biasing the model toward more recent tokens.
    • Bias does not depend on an absolute position index, so the model can naturally extend to longer sequences than it saw in training. You just continue applying the linear penalty beyond the training length.
    • Different attention heads use different penalty slopes (following a geometric series), allowing some heads to focus locally while others attend broadly.
    • demonstrated training on 1,024 tokens and evaluation successfully at 8–16 times that length.
    • ALiBi models generalise to longer sequences better than any other method zero-shot. A model trained on 1,024 tokens works perfectly on 2,048 tokens because the penalty mechanism is inherently length-agnostic. It was used in MPT (MosaicML) and BLOOM. However, its rigidity (linear decay) makes it potentially less expressive than the complex-valued rotations of RoPE, which can learn more nuanced distance relationships. RoPE's flexibility, combined with YaRN-style fixes, ultimately won the adoption war for Llama 3 and GPT-4.
  • Relative Positional Embeddings

    • introduced with Transformer-XL and further refined in subsequent models
    • remove the fixed limit of absolute positions. Instead of embedding the absolute index, they embed the distance between tokens (often clipped to some maximum).
    • The model learns to weight interactions based on the distance between tokens, rather than their absolute positions in the sequence.
    • The model can slide to text of greater length without encountering unseen position indices.
  • Rotary position embeddings (RoPE)

    • applies a rotating transform to queries/keys that encodes positions implicitly
    • dot products are invariant to shared rotations: if query and key vectors are both rotated by their respective positions, the attention score depends only on relative position.
    • Each pair of embedding dimensions is treated as a complex number and rotated by an angle proportional to position.
      • For a token at position $m$ and a feature dimension pair, the rotation is defined as $f(x, m) = x \cdot e^{im\theta}$. In matrix form for pairs of dimensions: $\begin{pmatrix} q'_0 \\ q'_1 \end{pmatrix} = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} q_0 \\ q_1 \end{pmatrix}$
      • The attention score (dot product) between a Query at position $m$ and a Key at position $n$ becomes $\langle f(q, m), f(k, n) \rangle = q^T R\big((n - m)\theta\big)\, k$, where $R$ is the rotation matrix above.
      • The result depends only on the relative distance $(m - n)$, not on the absolute values of $m$ or $n$. This property, a form of translation (shift) invariance, theoretically allows the model to handle sequences of arbitrary length, as the interaction mechanism is identical regardless of where in the sequence the tokens appear (see the code sketch after this list).
    • Different dimension pairs rotate at different frequencies, creating a rich positional signal that decays naturally with distance.
    • better generalisation to longer sequences than naive sinusoids
    • extensions:
      • XPos (2023) improves its extrapolation ability by rescaling the rotation frequency for longer lengths
    • Adopted more widely than ALiBi
    • Even with RoPE, models trained on 4K contexts failed beyond that length, as rotation angles extrapolated to untested regimes caused attention score instability.
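
A NumPy sketch of both schemes discussed above (referenced in the ALiBi and RoPE items); the head-slope series and base frequency follow common conventions, but the exact values are assumptions rather than any specific model's configuration:

```python
import numpy as np

def alibi_bias(n, num_heads):
    """ALiBi: per-head linear distance penalty added to the attention scores.
    Slopes form a geometric series (1/2, 1/4, ..., 1/256 for 8 heads)."""
    slopes = np.array([2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)])
    dist = np.abs(np.arange(n)[:, None] - np.arange(n)[None, :])    # |i - j|
    return -slopes[:, None, None] * dist                            # (heads, n, n), added to Q K^T

def rope_rotate(x, positions, base=10000.0):
    """RoPE: rotate each pair of dimensions of x (n, d) by an angle proportional
    to position; applied to both queries and keys before the dot product."""
    n, d = x.shape
    theta = base ** (-np.arange(0, d, 2) / d)      # one frequency per dimension pair
    angles = positions[:, None] * theta[None, :]   # (n, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out
```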

While RoPE is relative, models still failed when $m - n$ exceeded the maximum distance seen during training.

  • The rotation frequencies $\theta_d$ typically decay exponentially across the dimension index $d$.
  • At distances beyond the training length, the low-frequency dimensions (those that never completed a full rotation in training) reach "phases" (angles) the model has never correlated with semantic meaning
  • led to high attention uncertainty and "garbage" outputs.
  • known as the extrapolation cliff

https://arxiv.org/pdf/2309.16039

https://arxiv.org/pdf/2406.13282v1

  • Position Interpolation

    • 2023
    • https://arxiv.org/pdf/2306.15595
    • rescales position indices to fit within the original training range: positions are mapped to $m' = m / s$, where $s$ is the scale factor (sketched in code after this list)
    • A model trained on 4K contexts, when extended to 32K, simply divides all positions by 8, mapping the new 32K range into the original 0–4K range.
    • One can take a model like LLaMA-2 (trained on 4K context with RoPE) and interpolate the RoPE angles to a 16K or 32K context range, then fine-tune on some data with the extended context. Only takes about 1000 steps.
    • effectively “stretches” the learned positional representation to cover a longer sequence
    • Empirically, this can work surprisingly well, albeit typically with some drop in perplexity that can often be recovered through a bit of additional training on long sequences
    • This kind of post-hoc extension is a testament to how much of the mechanics for long context can be achieved by adjusting positional encodings and continuing training, without needing to redesign the entire architecture
    • But it did reduce the "resolution" of the attention. By squeezing the positions, adjacent tokens became harder to distinguish, degrading performance on tasks requiring high local precision.
  • Neural Tangent Kernel (NTK) aware scaling

    • https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have
    • Uniform interpolation damages high-frequency positional signals. We should not interpolate all dimensions equally
      • High-Frequency Dimensions (Local Context): These encode information like "token A is next to token B." Interpolating (stretching) these destroys local syntax. We should extrapolate these (keep them as is).
      • Low-Frequency Dimensions (Global Context): These encode long-range distance. These are the ones that hit the "limit" and need to be interpolated (stretched) to fit the new length.
    • We apply a non-linear transform to the RoPE base frequency
      • effectively "slows down" the rotation of the low-frequency components (allowing them to stretch to infinity) while keeping the high-frequency components fast (preserving local resolution)
  • YaRN (Yet another RoPE extensioN)

    • 2023
    • https://arxiv.org/pdf/2309.00071
    • refined NTK-Awareness by addressing the Entropy problem
      • As the sequence length $N$ grows, the attention distribution (softmax) naturally flattens because the probability mass is spread over more tokens.
      • The model becomes "unsure" (high entropy)
    • YaRN introduces a temperature scaling factor $t$ (where $t < 1$): $\text{Attention} = \text{softmax}\left(\frac{QK^T}{\sqrt{d} \cdot t}\right)$
    • By dividing by a smaller number (effectively multiplying the logits), YaRN "sharpens" the distribution, counteracting the entropy increase caused by the longer sequence.
    • Combined with a more sophisticated "NTK-by-parts" interpolation strategy, this enabled extending Llama 2 to a 128K context with just 0.1% of the original pre-training data
    • Dynamic variants adjust scaling based on current sequence length, enabling zero-shot context extension without fine-tuning.
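
A sketch of how these extension tricks modify RoPE's inputs, as referenced in the Position Interpolation item above. The NTK base-adjustment exponent $d/(d-2)$ and the YaRN temperature fit $\sqrt{1/t} \approx 0.1 \ln s + 1$ are the commonly cited heuristics; treat the exact constants as assumptions rather than a definitive implementation.

```python
import numpy as np

def rope_angles(positions, d, base=10000.0):
    """Per-position, per-dimension-pair rotation angles for standard RoPE."""
    theta = base ** (-np.arange(0, d, 2) / d)
    return positions[:, None] * theta[None, :]

def position_interpolation(positions, scale):
    """PI: squeeze new positions back into the trained range, m' = m / s."""
    return positions / scale

def ntk_aware_base(base, scale, d):
    """NTK-aware scaling: enlarge the base so low-frequency pairs are stretched
    while high-frequency pairs are left almost untouched."""
    return base * scale ** (d / (d - 2))

def yarn_temperature(scale):
    """YaRN softmax temperature t < 1; attention logits are divided by sqrt(d) * t."""
    return 1.0 / (0.1 * np.log(scale) + 1.0) ** 2

# Extend a model trained on 8K positions to 32K (scale factor s = 4):
pos = np.arange(32_768, dtype=np.float64)
pi_angles  = rope_angles(position_interpolation(pos, scale=4), d=128)
ntk_angles = rope_angles(pos, d=128, base=ntk_aware_base(10000.0, scale=4, d=128))
```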

Memory-efficient attention

https://arxiv.org/pdf/2112.05682

The bottleneck in attention is not the arithmetic of $QK^T$. It is the reading and writing of the $N \times N$ matrix to and from HBM (high-bandwidth memory, the GPU's main off-chip memory, which is far slower than on-chip SRAM).

Flash Attention

FlashAttention (2022) exploits this: it tiles the computation so that blocks of Q, K and V are processed in fast on-chip SRAM, computing the softmax incrementally (online softmax) and never materialising the full $N \times N$ matrix in HBM. Memory becomes linear in sequence length, and wall-clock speed improves even though the FLOP count is essentially unchanged.

A separate bottleneck appears at inference time. During autoregressive generation, each new token requires attending to all previous tokens.

  • Storing previously computed key and value vectors (the "KV-cache") avoids recomputation but grows linearly with sequence length.

  • For a model with $L$ layers and hidden dimension $d_{model}$, the KV-cache memory for a single sequence of length $N$ is:

    $\text{Memory}_{KV} = 2 \times N \times L \times d_{model} \times \text{Precision (bytes)}$

    For Llama 3 70B (approximate config: $L = 80$, $d_{model} = 8192$, and assuming full multi-head attention), a 128K-token context requires $2 \times 128{,}000 \times 80 \times 8192 \times 2 \approx 335$ GB. This exceeds the memory of a single H100 (80 GB): a single request would require 4+ GPUs just for the cache. (A worked comparison with grouped-query attention appears in code after this list.)

  • Attacking the raw size of the KV cache

    • Multi-head attention (MHA)
      • The traditional approach. Uses separate Key and Value projections for each head, so the cache stores a full set of keys and values per head.
    • Multi-Query Attention
      • 2019
      • https://arxiv.org/pdf/1911.02150
      • Shares a single key-value head across all query heads, while still having multiple Query projections.
      • Reduces KV-cache size by 8–32x depending on head count.
      • Low memory but degrades performance
    • Grouped-Query Attention (GQA)
      • 2023
      • https://arxiv.org/pdf/2305.13245
      • Divide query heads into groups that share key-value projections.
      • Llama 2's 70B model uses 8 key-value heads for 64 query heads, resulting in an 8x cache reduction with minimal quality loss.
      • Is a compromise between memory and performance. Standard in Llama 2/3, Mistral, and most modern inference-optimised models.
  • Paged Attention

    • 2023
    • https://arxiv.org/pdf/2309.06180
    • Traditional approaches pre-allocate contiguous memory for maximum sequence length, wasting 60–80% of KV-cache memory.
    • borrows from operating system virtual memory
      • divide KV-cache into fixed-size blocks e.g. 16 tokens
      • A "Block Table" maps the logical sequence (tokens 1, 2, 3...) to physical blocks scattered non-contiguously in GPU memory.
      • Memory is allocated block-by-block only as the sequence grows.
    • achieves near-zero waste (under 4%) and enables 2–4x throughput improvements.
  • KV-cache compression

    • INT8 quantisation halves memory with negligible quality loss; INT4 quarters it with minor degradation
    • KIVI (per-channel key quantisation, per-token value quantisation) achieves 2-bit compression with minimal quality degradation.
    • H2O (Heavy Hitter Oracle)
      • attention is approximately 95% sparse
      • dynamically retaining only recent tokens plus "heavy hitters" (tokens with the highest accumulated attention)
      • enables 5–20x cache reduction
  • StreamingLLM

    • 2023
    • https://arxiv.org/pdf/2309.17453
    • initial tokens serve as "attention sinks", absorbing unused attention probability mass due to softmax normalisation
    • Maintaining these sinks plus a sliding window enables theoretically infinite generation without fine-tuning, though it does not provide true long-term memory
    • tested to 4 million tokens
  • LongNet

    • 2023
    • https://arxiv.org/pdf/2307.02486
    • used two GPUs to train a model on extremely long sequences
    • took advantage of the fact that their attention mechanism allowed constant-size communication per device
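
A small Python calculation of the KV-cache formula above (referenced earlier in this list), comparing full multi-head attention with grouped-query attention; the Llama-style shapes are approximate:

```python
def kv_cache_bytes(n_tokens, n_layers, n_kv_heads, head_dim, bytes_per_value=2):
    """2 (K and V) x tokens x layers x KV heads x head dim x precision."""
    return 2 * n_tokens * n_layers * n_kv_heads * head_dim * bytes_per_value

# Roughly Llama-3-70B-shaped: 80 layers, 64 query heads, head_dim 128 (64 * 128 = 8192).
full_mha = kv_cache_bytes(128_000, 80, n_kv_heads=64, head_dim=128)  # every head cached
gqa_8kv  = kv_cache_bytes(128_000, 80, n_kv_heads=8,  head_dim=128)  # 8 shared KV heads
print(f"MHA: {full_mha / 1e9:.0f} GB, GQA: {gqa_8kv / 1e9:.0f} GB")  # roughly 335 GB vs 42 GB
```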

Scaling hardware and model parallelism

Even with FlashAttention (memory efficient) and RoPE (mathematically sound), physical hardware limits persist. The H100 GPU has 80GB of memory. A 70B parameter model uses ~140GB just for weights (FP16). The KV cache for a 1 million token context adds another ~100GB+ depending on architecture. No single GPU can hold this.

To handle "infinite" contexts (1M+ tokens), the sequence itself must be distributed across multiple GPUs. This domain is known as Sequence Parallelism.

  • DeepSpeed Ulysses

    • 2023
    • https://arxiv.org/pdf/2309.14509
    • partitions the sequence along the Attention Head dimension.
      • If a model has 64 heads and we have 8 GPUs, Ulysses ensures that each GPU computes the full attention for 8 specific heads over the entire sequence.
      • requires a massive "All-to-All" transpose operation. Before attention, GPUs exchange tokens so that each GPU has the full sequence for its subset of heads. After attention, they transpose back.
      • Parallelism is hard-capped by the number of attention heads. If the model has 64 heads, you cannot use more than 64 GPUs. For extremely long sequences, the memory per GPU might still be too high even with maximum splitting
  • Ring Attention

    • 2023
    • https://arxiv.org/pdf/2310.01889
    • https://arxiv.org/pdf/2411.01783
    • partitions the sequence across any number of devices, arranged in a conceptual ring
      1. The sequence ($Q$, $K$, $V$) is split into chunks. GPU 1 gets chunk 1, GPU 2 gets chunk 2, etc.
      2. Each GPU computes attention between its local Query chunk and its local KV chunk.
      3. GPU 1 sends its KV chunk to GPU 2, GPU 2 sends its KV chunk to GPU 3, and so on; GPU $N$ sends its chunk to GPU 1.
      4. While the KV chunks are travelling over the NVLink/InfiniBand interconnect, the GPUs are simultaneously computing attention with the KV chunk they currently hold.
      5. This repeats $N$ times until every Query has attended to every Key in the distributed system (a single-process simulation is sketched after this list).
    • Allows context length to scale linearly with the number of GPUs
    • The only cost is the latency of the ring pass, which is effectively hidden by the compute-heavy attention calculation.
    • Meta's context parallelism work (2024) refined Ring Attention for production inference, achieving 1 million token pre-fill with Llama 3.1 405B in 77 seconds on 128 H100s at 93% efficiency.
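
A single-process NumPy simulation of the ring schedule, as referenced in the Ring Attention item above: each "device" holds one Query chunk and accumulates its output with a running softmax as the KV chunks rotate past it (causal masking and the actual inter-GPU communication are omitted):

```python
import numpy as np

def ring_attention_sim(q_chunks, k_chunks, v_chunks):
    """Simulate Ring Attention on one machine. Each device i owns q_chunks[i] and,
    over n_dev steps, sees every KV chunk exactly once, combining partial results
    with a running (max, sum) so the full softmax is never materialised."""
    n_dev, d = len(q_chunks), q_chunks[0].shape[-1]
    outputs = []
    for i in range(n_dev):
        q = q_chunks[i]
        acc = np.zeros_like(q)                   # un-normalised output accumulator
        row_max = np.full(q.shape[0], -np.inf)   # running max of logits per query
        row_sum = np.zeros(q.shape[0])           # running softmax denominator
        for step in range(n_dev):
            j = (i + step) % n_dev               # KV chunk currently held by device i
            scores = q @ k_chunks[j].T / np.sqrt(d)
            new_max = np.maximum(row_max, scores.max(axis=-1))
            rescale = np.exp(row_max - new_max)  # fix up previously accumulated partials
            p = np.exp(scores - new_max[:, None])
            acc = acc * rescale[:, None] + p @ v_chunks[j]
            row_sum = row_sum * rescale + p.sum(axis=-1)
            row_max = new_max
        outputs.append(acc / row_sum[:, None])
    return np.concatenate(outputs, axis=0)

# Example: a length-1024 toy sequence split across 4 "devices".
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((1024, 64)) for _ in range(3))
out = ring_attention_sim(np.split(q, 4), np.split(k, 4), np.split(v, 4))
```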

Long context training techniques

  • Curriculum Learning

    • Training on long context tasks from the start is inefficient and unstable.
    • Standard recipe is something like:
      • Train on standard context (4k or 8k) for 90% of the training steps. This teaches language, logic, and facts.
      • Continue pre-training on progressively longer sequences (16k -> 64k -> 128k).
      • During this long-context fine-tuning (LCFT) stage, the RoPE base frequency is often increased (e.g., from 10,000 to 500,000) to accommodate the new lengths
  • Short context dataset balancing

    • models trained exclusively on long data "forget" how to handle short tasks
    • data mix during LCFT must remain balanced
  • Stabilising logit growth

    • As sequences grow, the attention logits (pre-softmax scores) tend to drift upward in magnitude.

    • $\text{Logits} \rightarrow \infty \implies \text{Softmax} \rightarrow \text{One-Hot}$

    • This collapse causes the gradients to vanish (Z-loss spikes) and training to diverge.

    • Mitigation:

      • Z-Loss Regularisation: a penalty term encouraging $\log \sum \exp(x)$ to remain small.
      • Logit Soft-Capping: explicitly capping the logits using a tanh function, $\text{logits} = \text{cap} \cdot \tanh(\text{logits} / \text{cap})$. This technique, used in Gemma 2 and Llama 3 prototypes, prevents the attention mechanism from "saturating" and ensures healthy gradient flow even at 100k+ tokens (a short sketch follows below).
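
A minimal sketch of the two mitigations above; the cap value and z-loss coefficient are illustrative defaults (Gemma 2 reportedly soft-caps attention logits at 50.0), not prescriptions:

```python
import numpy as np

def softcap(logits, cap=50.0):
    """Logit soft-capping: squash pre-softmax scores into (-cap, cap) with tanh
    so they cannot grow without bound as sequences get longer."""
    return cap * np.tanh(logits / cap)

def z_loss(logits, coeff=1e-4):
    """Auxiliary z-loss penalty: encourages log-sum-exp of the logits to stay
    near zero, keeping the softmax away from saturation."""
    lse = np.log(np.sum(np.exp(logits), axis=-1))
    return coeff * np.mean(lse ** 2)
```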

Possible successors to attention

  • Mega

    • 2022
    • attention with exponential decay
  • RetNet

  • State-space models

    • S4
    • Hydra
    • Mamba
      • replaces Attention with a selective scan mechanism
      • Throughput scales linearly, $O(N)$
      • The "KV Cache" does not grow. It is a fixed-size vector.
      • struggles with "Copying" tasks. Because it compresses the entire history into a fixed-size state, it cannot perfectly recall a specific token from 100k steps ago if that token was not deemed "important" at the time.
        • The Transformer, by keeping the full history (KV Cache), has "perfect memory".
    • Jamba
      • Interleave Mamba layers (for throughput and memory savings) with occasional Transformer Attention layers (for retrieval capability).
  • local+global models

    • GPT-J’s retrieval
  • Perceiver

    • A multi-modal model that uses a latent bottleneck to handle very large inputs by projecting them to a smaller set of latent vectors.

Other papers to read

Tags: AI