Towards Infinite Context: How LLMs Are Breaking the Context Limit
A comprehensive guide to extending LLM context windows through position encodings, efficient attention, and memory-augmented architectures.
Table of Contents
- The Context Length Problem
- Quadratic Attention Complexity
- Position Encoding Generalization
- Memory and Information Retrieval
- Part 1: Position Encoding for Length Generalization
- Absolute Position Embeddings
- Rotary Position Embeddings (RoPE)
- RoPE Scaling: Position Interpolation, NTK-Aware, YaRN
- ALiBi: Attention with Linear Biases
- Part 2: Efficient Attention for Long Context
- Sliding Window Attention
- Landmark Attention
- Ring Attention
- Attention Sinks (StreamingLLM)
- H2O, Scissorhands, SnapKV, PyramidKV, FastGen
- Infini-Attention
- LongRoPE
- Part 3: Memory Systems
- Retrieval-Augmented Generation (RAG)
- Memorizing Transformers
- Part 4: Context Compression
- Gisting / Prompt Compression
- ICAE: In-Context Autoencoder
- LongLLMLingua: Selective Pruning
- Part 5: Recurrent and Streaming Approaches
- RWKV
- Mamba / State Space Models
- Transformer-XL
- Comparison: What Works When?
- Some Final Thoughts
- References
ChatGPT launched in late 2022 with a 4K context window; fast forward to 2025 and Gemini 3 offers a context window of more than 1M tokens. So how did we bridge this 250x gap? This post explores the key innovations: position embeddings that generalize, attention mechanisms that scale, memory systems that extend beyond the context window, and the architectural changes making near-infinite context a reality. Long context is a complex problem that can be tackled from multiple perspectives, so I've broken it down into five key areas, each addressing a different aspect of the challenge. Fair warning: this one is going to be long!
The Context Length Problem
Why is extending context hard? Three fundamental challenges:
1. Quadratic Attention Complexity
Standard attention is $O(n^2)$ in sequence length, and here's why that matters. The attention mechanism computes a similarity score between every query token and every key token. With $n$ tokens, you have $n$ queries, each comparing against $n$ keys—that's $n^2$ comparisons. Each comparison involves a dot product over $d$ dimensions, giving us $O(n^2 \cdot d)$ floating-point operations.
At 1M tokens, the numbers get ridiculous:
For a typical model with $d_{\text{head}} = 128$ per head and 32 heads, that's roughly $4 \times 10^{15}$ operations just for attention. Even worse, you need to store the full $n \times n$ attention matrix in memory. At 1M tokens with FP16, that's about 2TB of memory just for one attention matrix—and you have dozens of layers.
FlashAttention helps by avoiding materializing the full matrix, but you're still doing $O(n^2)$ computation. The fundamental quadratic cost remains, making very long contexts computationally expensive even with optimizations.
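To make the arithmetic concrete, here's a quick back-of-the-envelope calculation; the 32 heads and 128 dims per head are assumptions for a "typical" model, not numbers from any specific architecture:

```python
# Back-of-the-envelope cost of full attention at 1M tokens
n = 1_000_000      # sequence length
d_head = 128       # dims per head (assumed)
n_heads = 32       # number of heads (assumed)
bytes_fp16 = 2

qk_ops = n * n * d_head * n_heads        # dot-product multiply-accumulates for QK^T
attn_matrix_bytes = n * n * bytes_fp16   # one n x n score matrix in FP16

print(f"QK^T ops per layer:    {qk_ops:.1e}")                       # ~4.1e15
print(f"Attention matrix size: {attn_matrix_bytes / 1e12:.1f} TB")  # ~2.0 TB
```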
2. Position Encoding Generalization
Models trained on 4K tokens don't automatically work on 100K tokens. Position encodings must generalize to unseen lengths.
During training, the model learns position embeddings only for the sequence lengths it has seen. A model trained on sequences up to 4,096 tokens has never encountered position 10,000—it's out-of-distribution. When you feed it a 100K token sequence, things break down: absolute position embeddings assign garbage values to positions they've never seen, learned relative distances that worked perfectly for "distance 4,000" suddenly mean nothing at "distance 50,000", and attention scores become unstable because the model can't properly compute relative positions between tokens far beyond its training range.
The solution requires position encoding schemes that extrapolate naturally—mathematical functions (like rotations in RoPE or linear biases in ALiBi) that are well-defined for any position, not just those seen during training. Without proper generalization, performance degrades on longer sequences.
3. Memory and Information Retrieval
Even if we solve the computational complexity, there's an information retrieval challenge: with 1M tokens, most of the context is irrelevant to any given query. Standard attention treats all tokens equally—each query attends to all 1M keys, which can dilute the signal with noise.
The challenge is identifying which tokens matter. This requires:
- Hierarchical attention patterns that prioritize recent or semantically important tokens
- Retrieval mechanisms that can quickly locate relevant chunks
- Memory systems that compress or summarize less-relevant history
- Selective attention that learns to ignore most of the context
Without these mechanisms, a 1M token context can become less useful as the model struggles to focus on relevant information, and performance may degrade despite having more context available.
Part 1: Position Encoding for Length Generalization
The first breakthrough: making position embeddings that extrapolate beyond training length.
Absolute Position Embeddings (The Old Way)
Original transformers add learned position embeddings to the token embeddings:

$$x_i = \text{TokenEmbed}(t_i) + \text{PosEmbed}(i)$$
Problem: If you train with positions 0-2047, position 2048 is out-of-distribution garbage.
Rotary Position Embeddings (RoPE)
The game-changer. Instead of adding positions, RoPE rotates query and key vectors:

$$q_m' = R_m q_m, \qquad k_n' = R_n k_n$$

where $R_m$ is a rotation matrix based on position $m$, built from 2D blocks of the form

$$\begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}, \qquad \theta_i = 10000^{-2i/d}$$

The attention score between positions $m$ and $n$ depends only on their relative distance:

$$(R_m q)^\top (R_n k) = q^\top R_{n-m}\, k$$
Fig 1: RoPE Encoding Diagram
import torch
def precompute_rope_frequencies(dim: int, max_seq_len: int, base: float = 10000.0):
"""Precompute RoPE rotation frequencies"""
# Compute theta values for each dimension pair
freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
# Compute position * frequency for all positions
positions = torch.arange(max_seq_len)
angles = torch.outer(positions, freqs) # [seq_len, dim/2]
# Return cos and sin
return torch.cos(angles), torch.sin(angles)
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Apply rotary embeddings to queries or keys"""
    # x: [batch, seq, heads, dim]
    seq_len = x.shape[1]
    # Reshape cos/sin so they broadcast over batch and heads: [seq, 1, dim/2]
    cos = cos[:seq_len].unsqueeze(1)
    sin = sin[:seq_len].unsqueeze(1)
    # Split into even/odd pairs for rotation
    x1, x2 = x[..., ::2], x[..., 1::2]
    # Apply the 2D rotation to each pair, then re-interleave
    rotated = torch.stack([
        x1 * cos - x2 * sin,
        x1 * sin + x2 * cos
    ], dim=-1).flatten(-2)
    return rotated
Why it works: Rotations are well-defined for any position. The model learns to interpret relative rotations, which generalize beyond training length.
RoPE Scaling: Extending Trained Models
Even RoPE has limits—models trained on 4K struggle at 32K. Solutions:
Position Interpolation (Linear Scaling)
Simply rescale positions to fit within the training range:

$$m' = \frac{m}{s}, \qquad s = \frac{L_{\text{target}}}{L_{\text{train}}}$$
A model trained on 4K tokens, when scaled to 32K, sees position 32000 as position 4000.
def scaled_rope(position: int, scale_factor: float):
"""Linear position interpolation"""
    return position / scale_factor
Problem: Compressed positions reduce resolution. Nearby tokens become indistinguishable.
NTK-Aware Scaling
Instead of scaling positions, scale the frequency base:

$$\text{base}' = \text{base} \cdot s^{\,d/(d-2)}$$

where $s$ is the scaling factor. This preserves high-frequency (local) information while extending low-frequency (global) range.
def ntk_scaled_rope(dim: int, max_seq_len: int, base: float = 10000.0, scale: float = 1.0):
"""NTK-aware RoPE scaling"""
# Scale the base frequency
base = base * (scale ** (dim / (dim - 2)))
freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
positions = torch.arange(max_seq_len)
angles = torch.outer(positions, freqs)
    return torch.cos(angles), torch.sin(angles)
YaRN (Yet another RoPE extensioN)
The state-of-the-art. Combines NTK-style interpolation (applied non-uniformly across frequency dimensions) with an attention temperature adjustment:

$$\text{softmax}\!\left(\frac{q_m^\top k_n}{t\sqrt{d}}\right)$$

where the temperature $t$ is tuned to restore the attention distribution's entropy at longer lengths.
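For the temperature part, here's a minimal sketch assuming the empirical rule reported in the YaRN paper, $\sqrt{1/t} = 0.1 \ln(s) + 1$; the non-uniform per-dimension interpolation is omitted, and `attention_with_yarn_temperature` is an illustrative toy function, not the reference implementation:

```python
import math
import torch
import torch.nn.functional as F

def yarn_attention_scale(scale_factor: float) -> float:
    """sqrt(1/t) = 0.1 * ln(s) + 1, the scaling rule suggested by the YaRN paper."""
    return 0.1 * math.log(scale_factor) + 1.0

def attention_with_yarn_temperature(Q, K, V, scale_factor: float):
    """Toy attention with YaRN's logit temperature; assumes RoPE (with the
    interpolated frequencies) has already been applied to Q and K."""
    d = Q.shape[-1]
    m = yarn_attention_scale(scale_factor)      # m = sqrt(1/t)
    # Scaling both Q and K by m multiplies the logits by m^2 = 1/t
    scores = (Q * m) @ (K * m).transpose(-2, -1) / math.sqrt(d)
    return F.softmax(scores, dim=-1) @ V
```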
ALiBi: Attention with Linear Biases
Alternative to RoPE: add a linear bias based on distance directly to the attention scores:

$$\text{softmax}\!\left(q_i^\top k_j - m \cdot |i - j|\right)$$

where $m$ is a head-specific slope.
def create_alibi_bias(num_heads: int, seq_len: int):
"""Create ALiBi attention bias matrix"""
# Slopes: geometric sequence from 2^(-8/n) to 2^(-8)
slopes = 2 ** (-8 * torch.arange(1, num_heads + 1) / num_heads)
# Distance matrix
positions = torch.arange(seq_len)
distances = positions.unsqueeze(0) - positions.unsqueeze(1) # [seq, seq]
distances = distances.abs()
# Bias: [num_heads, seq, seq]
bias = -slopes.view(-1, 1, 1) * distances.unsqueeze(0)
    return bias
Advantage: Extrapolates naturally—linear decay extends to any length.
Disadvantage: The linear decay assumption may not match all tasks. Some heads might want non-monotonic attention patterns.
Part 2: Efficient Attention for Long Context
Position encodings solve how to represent long contexts mathematically—RoPE and ALiBi let us extrapolate to sequences far beyond training length. But here's the catch: being able to represent a million tokens doesn't mean we can afford to compute attention over them. The quadratic complexity that made attention powerful at short lengths becomes a computational nightmare at scale. I've discussed this in detail in my other article "on solving quadratic complexity of attention".
Sliding Window Attention
In a nutshell: each token only attends to a local window of nearby tokens instead of the entire sequence. Think of it like reading a book through a small window—you can only see a few pages at a time, but you can slide the window to read the whole book.
The key insight is that most information needed for generating the next token is local. A word typically depends on nearby words, not tokens thousands of positions away. By restricting attention to a fixed window size $w$, we reduce complexity from $O(n^2)$ to $O(n \cdot w)$—linear in sequence length, constant in window size.
Each token at position $i$ attends only to tokens within its window:

$$\text{Attend}(i) = \{\, j : |i - j| \le w/2 \,\}$$
The window slides along the sequence, creating a "local receptive field" for each position. For a 4K window, token 10,000 can directly attend to tokens 8,000-12,000, but not to token 1,000 or token 50,000. Information can still propagate long distances, but it must "hop" through multiple layers—each layer can move information by at most the window size.
Comparison of regular causal attention and sliding window attention patterns
import torch.nn.functional as F
def sliding_window_attention(Q, K, V, window_size: int):
    """Sliding window attention with O(n*w) complexity.
    Q, K, V: [batch, seq, heads, head_dim]"""
    batch, seq_len, num_heads, head_dim = Q.shape
    # Work in [batch, heads, seq, dim] layout for the matmuls
    Q, K, V = Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2)
    outputs = []
    for i in range(seq_len):
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2 + 1)
        q_i = Q[:, :, i:i+1, :]              # [batch, heads, 1, dim]
        k_window = K[:, :, start:end, :]     # [batch, heads, win, dim]
        v_window = V[:, :, start:end, :]
        scores = torch.matmul(q_i, k_window.transpose(-2, -1)) / (head_dim ** 0.5)
        attn = F.softmax(scores, dim=-1)
        outputs.append(torch.matmul(attn, v_window))
    # Back to [batch, seq, heads, dim]
    return torch.cat(outputs, dim=2).transpose(1, 2)
Used by: Mistral (4K window), Gemma 2 (alternating local/global)
Limitation: This creates an interesting trade-off: local attention is fast and memory-efficient, but long-range dependencies require deep networks. For a 1M token context with a 4K window, information needs roughly 250 layer hops to travel from the beginning to the end. Some models like Mistral use pure sliding windows, while others like Gemma 2 alternate between local and global attention layers to balance efficiency and long-range modeling.
Landmark Attention
Instead of attending to all tokens, insert special "landmark" tokens that summarize chunks of the sequence. Think of landmarks as a table of contents—you first check the summary to find which chapter is relevant, then dive into that specific section.
The key idea is hierarchical attention: first route through landmarks to identify relevant chunks, then retrieve detailed information from those chunks. This creates a two-stage retrieval process that's much more efficient than full attention.
The sequence is structured with landmarks inserted between chunks:

[ chunk 1 ] [LMK] [ chunk 2 ] [LMK] [ chunk 3 ] [LMK] ...
Each landmark token is created by pooling (typically mean pooling) the tokens in its corresponding chunk, creating a compressed representation. When a query needs information, it first attends to the landmarks—a small set of summary tokens—to identify which chunks are relevant. Then it attends within those selected chunks to get the detailed information.
This two-stage process dramatically reduces computation: instead of comparing against all $n$ tokens, each query compares against $m$ landmarks (where $m = n/c \ll n$), then attends to only the relevant chunks. For a 1M token sequence divided into 1000 chunks, the cost per query drops from $O(n)$ to roughly $O(n/c + k \cdot c)$, where $k$ is the number of selected chunks and $c$ is the chunk size.
Comparison of standard attention and attention with landmarks
import torch.nn as nn
import torch.nn.functional as F
class LandmarkAttention(nn.Module):
    def __init__(self, d_model: int, chunk_size: int = 512, top_k: int = 3):
        super().__init__()
        self.chunk_size = chunk_size
        self.top_k = top_k
        self.landmark_proj = nn.Linear(d_model, d_model)
    def forward(self, x):
        batch, seq_len, d = x.shape
        # Create landmark tokens by mean-pooling each chunk
        num_chunks = seq_len // self.chunk_size
        chunks = x.view(batch, num_chunks, self.chunk_size, d)
        landmarks = self.landmark_proj(chunks.mean(dim=2))       # [batch, num_chunks, d]
        # Stage 1: every query attends to landmarks to find its relevant chunks
        landmark_scores = torch.matmul(x, landmarks.transpose(-2, -1))   # [batch, seq, num_chunks]
        top_chunks = landmark_scores.topk(k=self.top_k, dim=-1).indices  # [batch, seq, top_k]
        # Stage 2: gather each query's selected chunks and attend within them
        idx = top_chunks[..., None, None].expand(-1, -1, -1, self.chunk_size, d)
        selected = chunks.unsqueeze(1).expand(-1, seq_len, -1, -1, -1).gather(2, idx)
        selected = selected.reshape(batch, seq_len, self.top_k * self.chunk_size, d)
        scores = torch.einsum('bsd,bskd->bsk', x, selected) / (d ** 0.5)
        attn = F.softmax(scores, dim=-1)
        return torch.einsum('bsk,bskd->bsd', attn, selected)
Advantage: Provides random-access to any part of the context—you can jump directly to relevant chunks without processing everything in between. This is particularly powerful for tasks like question answering where you need to retrieve specific information from a long document.
Trade-off: The quality depends on how well landmarks summarize their chunks. If a landmark misses important information, that chunk might be overlooked even if it contains relevant details. The chunk size also matters—too small and you have too many landmarks, too large and landmarks become less informative.
Ring Attention (Distributed Long Context)
For truly massive contexts that don't fit on a single GPU, Ring Attention distributes the computation across multiple devices. The key insight is that attention needs queries (Q) to attend to keys and values (K, V), but we can split the sequence across GPUs and pass K/V blocks in a ring pattern.
Here's how it works: imagine you have a long sequence that you split across $N$ GPUs. Each GPU holds its local chunk of queries, keys, and values. The problem is that each query needs to attend to keys from all GPUs, not just its local ones. Ring Attention solves this by having each GPU compute attention with its local K/V, then pass those K/V blocks to the next GPU in a ring. After $N$ steps, each GPU has seen all K/V blocks from all GPUs.
The process:
- Partition: Split the sequence across $N$ GPUs—each GPU gets $n/N$ tokens
- Local computation: Each GPU computes attention with its local Q and local K/V
- Ring pass: Each GPU sends its K/V block to the next GPU and receives K/V from the previous GPU
- Accumulate: Each GPU accumulates attention outputs across all steps
With 8 GPUs each handling 128K tokens = 1M token effective context. The memory per GPU stays constant regardless of total context length—you're just doing more communication rounds.
Ring Attention: Circular data flow of K/V blocks across multiple GPUs
def ring_attention_step(Q_local, K_local, V_local, rank: int, world_size: int):
    """One ring-attention pass (simplified). `ring_send_recv` stands in for a
    distributed send/recv that rotates K/V blocks around the ring of GPUs."""
    d = Q_local.shape[-1]
    O_local = torch.zeros_like(Q_local)                      # running (unnormalized) output
    m_local = torch.full(Q_local.shape[:-1], float('-inf'))  # running max for online softmax
    l_local = torch.zeros(Q_local.shape[:-1])                # running softmax denominator
    K_recv, V_recv = K_local, V_local
    for step in range(world_size):
        # Attention scores against the K/V block currently held
        scores = Q_local @ K_recv.transpose(-2, -1) / d ** 0.5
        # Online softmax update (the same trick FlashAttention uses)
        m_new = torch.maximum(m_local, scores.max(dim=-1).values)
        p = torch.exp(scores - m_new.unsqueeze(-1))
        correction = torch.exp(m_local - m_new)
        l_local = l_local * correction + p.sum(dim=-1)
        O_local = O_local * correction.unsqueeze(-1) + p @ V_recv
        m_local = m_new
        # Ring: send current K/V block to the next GPU, receive from the previous
        K_recv = ring_send_recv(K_recv, rank, world_size)
        V_recv = ring_send_recv(V_recv, rank, world_size)
    return O_local / l_local.unsqueeze(-1)
Attention Sinks (StreamingLLM)
A surprising discovery: LLMs allocate massive attention to the first few tokens, regardless of their content. These "attention sinks" are crucial for model stability, and understanding why reveals something fundamental about how transformers work.
Unexpected Discovery: the first few tokens (often just the BOS token and maybe 2-3 more) consistently receive disproportionately high attention scores across all layers and heads. This isn't because these tokens are semantically important—it's because the model learned during training that it needs a "sink" to dump excess attention into. Without this sink, attention distributions become unstable.
Why do attention sinks exist? Attention mechanisms need a way to normalize—when a query doesn't strongly match any key, it still needs to distribute its attention somewhere. The model learned to use the initial tokens as this "somewhere," creating a stable baseline for attention distributions.
The Problem with Sliding Window
When you naively slide a window over long text, you eventually evict the initial tokens:
Window at t=0: [BOS, tok1, tok2, tok3, tok4, ...]
Window at t=1000: [tok997, tok998, tok999, tok1000, ...] ← BOS is gone!
Without these sink tokens, attention scores become unstable and perplexity explodes. The model tries to redistribute attention that was going to the sinks, but there's nowhere stable to put it. Attention patterns become erratic, with tokens randomly receiving very high or very low attention, leading to poor generation quality.
The Solution: Keep the Sinks
StreamingLLM's elegant solution: keep a small number of initial "sink" tokens permanently in the KV cache, then add a sliding window for recent tokens. This gives the model both the stability anchors it needs and access to recent context:
The sink tokens act as attention anchors—they absorb excess attention and provide a stable baseline. The sliding window provides access to recent context. Together, they enable infinite streaming with stable perplexity. Remarkably, you only need about 4 sink tokens to maintain stability, regardless of how long the sequence gets.
class StreamingLLMCache:
def __init__(
self,
num_sink_tokens: int = 4,
window_size: int = 1024,
num_layers: int = 32
):
self.num_sink = num_sink_tokens
self.window_size = window_size
# KV cache structure: sink tokens + sliding window
self.sink_cache = None # [layers, 2, batch, num_sink, head_dim]
self.window_cache = None # [layers, 2, batch, window_size, head_dim]
self.window_pos = 0
def update(self, new_k: torch.Tensor, new_v: torch.Tensor, layer: int):
"""Add new KV, maintaining sink + window structure"""
if self.sink_cache is None:
# First tokens become sink tokens
self.sink_cache = ...
return
# Add to sliding window (circular buffer)
pos = self.window_pos % self.window_size
self.window_cache[layer, 0, :, pos, :] = new_k
self.window_cache[layer, 1, :, pos, :] = new_v
self.window_pos += 1
def get_kv(self, layer: int):
"""Return sink + recent window for attention"""
sink_k = self.sink_cache[layer, 0]
sink_v = self.sink_cache[layer, 1]
# Get window in correct order
if self.window_pos < self.window_size:
window_k = self.window_cache[layer, 0, :, :self.window_pos, :]
window_v = self.window_cache[layer, 1, :, :self.window_pos, :]
else:
# Reorder circular buffer
start = self.window_pos % self.window_size
window_k = torch.cat([
self.window_cache[layer, 0, :, start:, :],
self.window_cache[layer, 0, :, :start, :]
], dim=1)
window_v = torch.cat([
self.window_cache[layer, 1, :, start:, :],
self.window_cache[layer, 1, :, :start, :]
], dim=1)
return (
torch.cat([sink_k, window_k], dim=1),
torch.cat([sink_v, window_v], dim=1)
        )
Key insight: Only 4 sink tokens + 1K window enables infinite streaming with stable perplexity.
Used by: Many production streaming systems, forms the basis for efficient long-context serving.
H2O: Heavy-Hitter Oracle
Not all KV cache entries are equal. H2O identifies "heavy hitter" tokens—those that accumulate high attention scores—and keeps them while evicting the rest.
The Observation
In practice, attention follows a power law: a small fraction of tokens receive most of the attention. H2O exploits this:
H2O operates by continuously tracking cumulative attention scores for each token in the KV cache. As new tokens arrive, it updates these scores and identifies which tokens consistently receive high attention across multiple queries. When the cache exceeds its budget, H2O evicts tokens with low accumulated attention while preserving the recent window and top heavy hitters. This online approach adapts to the actual attention patterns during generation, making it more effective than fixed heuristics.
class H2OCache:
def __init__(self, window_size: int = 256, heavy_hitter_budget: int = 256):
self.window_size = window_size
self.hh_budget = heavy_hitter_budget
self.kv_cache = None
self.attention_accumulator = None # Track cumulative attention per token
    def update(self, new_k, new_v, attention_scores):
        """Update cache with new KV and attention information.
        (For brevity this sketch tracks only K; V is handled identically.)"""
        if self.kv_cache is None:
            # First token(s): initialize cache and accumulator
            self.kv_cache = new_k
            self.attention_accumulator = torch.zeros(new_k.shape[1])
            return
        # Accumulate attention received by existing tokens, summed over the new queries
        # attention_scores: [num_new_queries, cache_len]
        self.attention_accumulator += attention_scores.sum(dim=0)
        # Add new token
        self.kv_cache = torch.cat([self.kv_cache, new_k], dim=1)
        self.attention_accumulator = torch.cat([
            self.attention_accumulator,
            torch.zeros(new_k.shape[1])
        ])
        # Evict if over budget
        if self.kv_cache.shape[1] > self.window_size + self.hh_budget:
            self._evict()
def _evict(self):
"""Keep recent window + top heavy hitters"""
cache_len = self.kv_cache.shape[1]
# Always keep recent window
recent_start = cache_len - self.window_size
# Find heavy hitters in older tokens
old_attention = self.attention_accumulator[:recent_start]
hh_indices = old_attention.topk(self.hh_budget).indices
# Combine: heavy hitters + recent
keep_indices = torch.cat([hh_indices, torch.arange(recent_start, cache_len)])
self.kv_cache = self.kv_cache[:, keep_indices]
        self.attention_accumulator = self.attention_accumulator[keep_indices]
Key insight: Heavy hitters are often semantically important tokens (subjects, key entities). Keeping them preserves critical information.
Scissorhands: Persistence-Based Eviction
Similar to H2O but uses persistence of importance—tokens that remain important across many steps are kept.
The intuition: a token attended to heavily at step 100 but ignored from step 101-500 is less important than one consistently attended to.
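Here's a rough sketch of how a persistence score could drive eviction—my own illustrative simplification (the function name, the binary `attn_history` matrix, and the budget handling are assumptions, not the paper's exact algorithm):

```python
import torch

def persistence_evict(attn_history: torch.Tensor, budget: int, recent_window: int = 128):
    """Illustrative persistence-based eviction, not the exact Scissorhands method.
    attn_history: [steps, kv_len] binary matrix; entry is 1 if that cached token was
    among the top-attended tokens at that generation step."""
    kv_len = attn_history.shape[1]
    cutoff = max(0, kv_len - recent_window)
    # Persistence = fraction of steps in which the token stayed "important"
    persistence = attn_history.float().mean(dim=0)            # [kv_len]
    recent = torch.arange(cutoff, kv_len)                     # always keep recent tokens
    old_scores = persistence[:cutoff]
    num_old = min(max(0, budget - len(recent)), cutoff)       # budget left for older tokens
    keep_old = old_scores.topk(num_old).indices
    return torch.cat([keep_old, recent]).sort().values        # KV indices to keep
```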
SnapKV: Observation Window + Pooling
SnapKV discovers that heavy hitters can be identified from just a small "observation window" at the end of the prompt, then compresses the rest:
- Run attention over the last $w$ tokens (an "observation window") to identify important positions
- Select top-k positions per head based on pooled attention
- Keep only those positions in KV cache for generation
def snapkv_compress(K, V, attention_scores, num_keep: int = 512, window: int = 64):
    """Compress KV cache using the SnapKV strategy.
    K, V: [batch, heads, seq, head_dim]; attention_scores: [batch, heads, seq, seq]"""
    # Use attention from the last `window` queries to identify important keys
    observation_attn = attention_scores[:, :, -window:, :]      # [batch, heads, window, seq]
    # Pool attention across the observation window
    importance = observation_attn.mean(dim=2)                   # [batch, heads, seq]
    # Select top-k per head (excluding the observation window itself)
    prefix_importance = importance[:, :, :-window]
    top_indices = prefix_importance.topk(num_keep, dim=-1).indices  # [batch, heads, num_keep]
    # Gather compressed KV
    idx = top_indices.unsqueeze(-1).expand(-1, -1, -1, K.shape[-1])
    K_compressed = K[:, :, :-window].gather(2, idx)
    V_compressed = V[:, :, :-window].gather(2, idx)
    # Append the observation window (always kept)
    K_out = torch.cat([K_compressed, K[:, :, -window:]], dim=2)
    V_out = torch.cat([V_compressed, V[:, :, -window:]], dim=2)
    return K_out, V_out
Advantage: Only needs one attention pass to identify what to keep—efficient for prompt processing.
PyramidKV: Layer-Wise Budget Allocation
Different layers need different KV cache sizes. Lower layers spread attention broadly across many tokens and need a larger cache, while higher layers funnel attention onto a few important tokens and can get by with less.

$$b_l \propto \alpha^{\,L - l - 1}$$

where $\alpha > 1$ gives more budget to earlier layers.
def pyramid_budgets(num_layers: int, total_budget: int, alpha: float = 1.5):
    """Allocate KV cache budget per layer (pyramid shape)"""
    # Earlier (lower) layers get more budget; budgets shrink toward the top
    raw_budgets = [alpha ** (num_layers - l - 1) for l in range(num_layers)]
    total_raw = sum(raw_budgets)
    # Normalize to total budget
    budgets = [int(b / total_raw * total_budget) for b in raw_budgets]
    return budgets  # decreasing, e.g. roughly [486, 324, 216, 144, 96, 64, 43] for 7 layers
FastGen: Adaptive KV Compression
FastGen takes an adaptive approach to KV cache compression by profiling attention patterns during prompt processing and applying learned compression strategies. The method first analyzes attention entropy and variance across heads to determine per-head compression ratios—heads with low entropy (focused attention) can be compressed more aggressively, while heads with high entropy (distributed attention) need more context preserved.
Based on this profiling, FastGen applies multiple adaptive strategies:
- Special tokens: Always kept since they receive high attention
- Punctuation tokens: Often dropped as they're less critical
- Content tokens: Selectively pruned based on their attention profiles
The key innovation is that compression ratios vary per attention head, recognizing that some heads are naturally compressible while others need full context. This adaptive approach allows FastGen to achieve better compression-quality trade-offs compared to fixed heuristics that treat all tokens and heads equally.
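To make the per-head idea concrete, here's a rough sketch of profiling-based policy selection; the function name, entropy threshold, and two-policy menu are illustrative assumptions—FastGen's actual policy set (special tokens, punctuation, locality, hybrids) is richer:

```python
import torch

def choose_head_policies(attn: torch.Tensor, entropy_threshold: float = 2.0):
    """Pick a compression policy per head from profiled prompt attention.
    attn: [heads, queries, keys] attention probabilities from the prompt pass."""
    probs = attn.mean(dim=1)                                   # [heads, keys] average attention
    entropy = -(probs * (probs + 1e-9).log()).sum(dim=-1)      # per-head attention entropy
    policies = []
    for head_entropy in entropy:
        if head_entropy < entropy_threshold:
            policies.append("aggressive")   # focused head: keep only the top-attended keys
        else:
            policies.append("full")         # dispersed head: keep the full KV cache
    return policies
```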
Comparison of Streaming Methods
| Method | What to Keep | When to Decide | Memory | Overhead |
|---|---|---|---|---|
| Attention Sinks | First tokens + recent | Fixed | O(w+k) | None |
| H2O | Heavy hitters + recent | Online every step | O(w+k) | Low |
| Scissorhands | Persistent hitters + recent | Online | O(w+k) | Low |
| SnapKV | Important observed + recent | Once at prompt end | O(k) | Low |
| PyramidKV | Per-layer budgets | Once | O(k*L) varying | None |
| FastGen | Adaptive per-head | Once | O(k) varying | Medium |
The trend: Move from fixed heuristics (Attention Sinks) to learned/adaptive selection (SnapKV, FastGen) that can handle diverse workloads.
Infini-Attention (Google's 1M+ Method)
Google's candidate technique for massive contexts (and a plausible ingredient behind Gemini's). Combines local attention with a compressive memory that summarizes discarded history.
The Core Idea
Instead of throwing away old KV pairs, compress them into a fixed-size memory:

$$O = A_{\text{local}}(Q, K, V) + \beta \cdot A_{\text{mem}}(Q, M_s)$$

where $M_s$ is a compressive memory updated incrementally as each segment is processed, and $\beta$ is a learned gate.
Infini-Attention processes sequences in segments. For each segment, it performs two parallel attention operations: standard causal attention over the current segment's K/V pairs, and linear attention retrieval from the compressive memory that stores compressed information from all previous segments. The outputs are combined using a learnable gate that controls how much to rely on local vs. memory attention. This hybrid approach gives you both the precision of standard attention for recent context and the efficiency of compressed memory for distant history.
Infini-Attention architecture: hybrid of local causal attention and compressive memory
Memory as Associative Binding
Infini-attention uses linear attention to maintain memory. The key insight is that memory can be updated incrementally by accumulating compressed KV information:

$$M_s = M_{s-1} + \sigma(K_s)^\top V_s, \qquad z_s = z_{s-1} + \sum_t \sigma(K_{s,t})$$

This update rule means that as you process each segment, you compress its K/V pairs and add them to the memory matrix. The memory has fixed size regardless of how many segments you've processed—it's a compressed summary of all history.
When retrieving information, queries attend to the memory using linear attention:

$$A_{\text{mem}} = \frac{\sigma(Q)\, M_{s-1}}{\sigma(Q)\, z_{s-1}}$$

where $z$ is a normalization term that tracks the total "mass" of keys added to memory, and $\sigma$ is a non-linearity (e.g., ELU + 1) that ensures positive attention weights. This linear attention formulation avoids the quadratic cost of standard attention while still allowing queries to retrieve relevant information from the compressed memory.
class InfiniAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int, segment_len: int = 2048):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.segment_len = segment_len
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
# Learnable gate for combining local + memory attention
self.beta = nn.Parameter(torch.zeros(num_heads))
def forward(self, x, memory_state=None):
batch, seq_len, _ = x.shape
Q = self.W_q(x).view(batch, seq_len, self.num_heads, self.head_dim)
K = self.W_k(x).view(batch, seq_len, self.num_heads, self.head_dim)
V = self.W_v(x).view(batch, seq_len, self.num_heads, self.head_dim)
# Initialize memory if needed
if memory_state is None:
M = torch.zeros(batch, self.num_heads, self.head_dim, self.head_dim)
z = torch.zeros(batch, self.num_heads, self.head_dim)
else:
M, z = memory_state
outputs = []
# Process in segments
for seg_start in range(0, seq_len, self.segment_len):
seg_end = min(seg_start + self.segment_len, seq_len)
Q_seg = Q[:, seg_start:seg_end]
K_seg = K[:, seg_start:seg_end]
V_seg = V[:, seg_start:seg_end]
# === Local causal attention ===
local_out = F.scaled_dot_product_attention(
Q_seg.transpose(1, 2),
K_seg.transpose(1, 2),
V_seg.transpose(1, 2),
is_causal=True
).transpose(1, 2)
# === Memory retrieval ===
# σ(Q) @ M / (σ(Q) @ z)
Q_norm = F.elu(Q_seg) + 1 # [batch, seg_len, heads, head_dim]
# Retrieve from memory
mem_out = torch.einsum('bshd,bhde->bshe', Q_norm, M)
normalizer = torch.einsum('bshd,bhd->bsh', Q_norm, z).unsqueeze(-1) + 1e-6
mem_out = mem_out / normalizer
# === Combine with learnable gate ===
beta = torch.sigmoid(self.beta).view(1, 1, -1, 1)
combined = local_out + beta * mem_out
outputs.append(combined)
# === Update memory with this segment ===
K_norm = F.elu(K_seg) + 1
# M += σ(K)^T @ V
M = M + torch.einsum('bshd,bshe->bhde', K_norm, V_seg)
# z += sum(σ(K))
z = z + K_norm.sum(dim=1)
output = torch.cat(outputs, dim=1)
output = output.reshape(batch, seq_len, self.d_model)
        return self.W_o(output), (M, z)
Why It Scales to Millions of Tokens
| Component | Memory | Compute |
|---|---|---|
| Local attention (segment) | $O(s^2)$ | $O(s^2 \cdot d)$ |
| Memory retrieval | $O(d^2)$ — fixed | $O(s \cdot d^2)$ |
| Memory update | $O(d^2)$ — fixed | $O(s \cdot d^2)$ |
The compressive memory has fixed size regardless of context length (here $s$ is the segment length and $d$ the head dimension). Processing 1M tokens only requires compute that grows linearly with the number of segments—$O(n)$ for a fixed segment length.
Results: Infini-attention achieves:
- 1M token passkey retrieval with 1B parameter model
- 114x compression ratio vs. baseline memory
- Comparable quality to full attention on BookSum (summarizing 500K+ token books)
LongRoPE: 2M Context Extension
Microsoft's approach to extreme length extension via progressive interpolation:
- Search for optimal RoPE rescaling factors per dimension
- Progressive extension: 256K → 512K → 1M → 2M in stages
- Readjust for short contexts: Prevent degradation on original lengths
def longrope_scaling(dim: int, target_len: int, original_len: int = 4096):
"""LongRoPE non-uniform scaling factors"""
# Different dimensions get different scaling
# Low-frequency (high dim indices) scale more aggressively
scale_factor = target_len / original_len
# Searched optimal factors (simplified)
lambda_factors = torch.ones(dim // 2)
# High-frequency dimensions (local info): minimal scaling
lambda_factors[:dim//4] = 1.0
# Low-frequency dimensions (global info): aggressive scaling
lambda_factors[dim//4:] = scale_factor ** 0.5
    return lambda_factors
Part 3: Memory Systems (Beyond Context Windows)
What if we stop pretending everything fits in context?
Retrieval-Augmented Generation (RAG)
RAG has become super common these days, so I'll skip the basics. The core idea is simple: instead of putting everything in context, retrieve only what's relevant when you need it.
- Index: Embed documents into vector database
- Retrieve: Find top-k relevant chunks for query
- Generate: Use retrieved chunks as context
class RAGSystem:
def __init__(self, embedding_model, vector_db, llm):
self.embedder = embedding_model
self.db = vector_db
self.llm = llm
def query(self, question: str, k: int = 5):
# Embed the question
q_embedding = self.embedder.encode(question)
# Retrieve relevant chunks
chunks = self.db.search(q_embedding, top_k=k)
# Build context
context = "\n\n".join([c.text for c in chunks])
# Generate answer
prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
        return self.llm.generate(prompt)
Effective context: Unlimited (database can hold terabytes)
Limitation: Retrieval quality is crucial. Complex queries spanning multiple documents are hard.
Memorizing Transformers
Add an explicit kNN memory to attention:

$$\text{Attn}(q) = \text{Attn}_{\text{local}}(q, K, V) + g \cdot \text{Attn}_{\text{kNN}}(q, \mathcal{M})$$

where $\mathcal{M}$ is an external memory of past (key, value) pairs and $g$ is a mixing weight.
Memorizing Transformer architecture: combining local attention with kNN retrieval from external memory
class MemorizingAttention(nn.Module):
def __init__(self, d_model: int, memory_size: int = 65536):
super().__init__()
self.memory_keys = torch.zeros(memory_size, d_model)
self.memory_values = torch.zeros(memory_size, d_model)
self.memory_ptr = 0
self.memory_size = memory_size
def forward(self, Q, K, V):
# Standard local attention
local_attn = F.scaled_dot_product_attention(Q, K, V)
# kNN lookup in memory
# Find top-k most similar memory keys for each query
similarities = Q @ self.memory_keys.T # [batch, seq, memory_size]
top_k_idx = similarities.topk(k=32, dim=-1).indices
# Gather memory values
memory_v = self.memory_values[top_k_idx]
memory_attn = F.softmax(similarities.gather(-1, top_k_idx), dim=-1)
memory_out = (memory_attn.unsqueeze(-1) * memory_v).sum(-2)
# Combine
return local_attn + 0.1 * memory_out
def update_memory(self, K, V):
"""Add current K, V to memory"""
batch_size = K.shape[0] * K.shape[1]
end_ptr = (self.memory_ptr + batch_size) % self.memory_size
# FIFO update
self.memory_keys[self.memory_ptr:end_ptr] = K.flatten(0, 1)
self.memory_values[self.memory_ptr:end_ptr] = V.flatten(0, 1)
        self.memory_ptr = end_ptr
Part 4: Context Compression
Instead of extending context, compress it.
Gisting / Prompt Compression
Learn special "gist" tokens that summarize long contexts:
class GistCompressor(nn.Module):
    def __init__(self, llm, num_gist_tokens: int = 10):
        super().__init__()
        self.llm = llm
        self.gist_tokens = nn.Parameter(torch.randn(num_gist_tokens, llm.d_model))
        # Cross-attention that lets the gist tokens read the full context
        self.cross_attention = nn.MultiheadAttention(llm.d_model, num_heads=8, batch_first=True)
    def compress(self, long_context_ids: torch.Tensor):
        # Encode the long context
        hidden_states = self.llm.encode(long_context_ids)    # [batch, seq, d_model]
        # Cross-attention: gist tokens attend to context
        queries = self.gist_tokens.unsqueeze(0).expand(hidden_states.shape[0], -1, -1)
        gist_repr, _ = self.cross_attention(
            query=queries,
            key=hidden_states,
            value=hidden_states
        )
        return gist_repr  # [batch, num_gist_tokens, d_model]
AutoCompressor: Train models to recursively summarize their own context.
ICAE: In-Context Autoencoder
Train an encoder to compress the context into a small set of memory slots, and have the decoder consume those slots in place of the original context when needed.
Achieves ~30x compression with minimal quality loss.
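A minimal sketch of the idea, with illustrative module names and shapes of my own choosing (not the paper's implementation):

```python
import torch
import torch.nn as nn

class InContextAutoencoder(nn.Module):
    """Illustrative ICAE-style compressor: an encoder squeezes a long context into a
    small number of memory-slot embeddings for a decoder LLM to consume."""
    def __init__(self, d_model: int, num_memory_slots: int = 128, num_layers: int = 4):
        super().__init__()
        self.memory_slots = nn.Parameter(torch.randn(num_memory_slots, d_model))
        layer = nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def compress(self, context_embeds: torch.Tensor) -> torch.Tensor:
        # context_embeds: [batch, seq, d_model]
        batch = context_embeds.shape[0]
        slots = self.memory_slots.unsqueeze(0).expand(batch, -1, -1)
        # Append learnable slots to the context; after encoding, keep only the slots
        encoded = self.encoder(torch.cat([context_embeds, slots], dim=1))
        return encoded[:, -self.memory_slots.shape[0]:]  # [batch, num_slots, d_model]

# At generation time, the decoder LLM is fed these slot embeddings
# in place of the original (much longer) context.
```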
LongLLMLingua: Selective Pruning
Not all tokens matter equally. Prune unimportant ones:
- Score each token by perplexity contribution
- Keep tokens with high information content
- Compress 10K tokens to 2K with minimal information loss
def compress_context(context_ids: torch.Tensor, model, target_ratio: float = 0.3):
    """Compress context by keeping only important tokens
    (attention-based proxy for LongLLMLingua's perplexity-based scoring)"""
    # context_ids: [seq] token ids
    with torch.no_grad():
        outputs = model(context_ids.unsqueeze(0), output_attentions=True)
    # Aggregate attention across layers, heads (and the batch of 1): [seq, seq]
    importance = torch.stack(outputs.attentions).mean(dim=(0, 1, 2))
    token_importance = importance.sum(dim=0)  # how much each token is attended to
    # Keep top tokens, preserving their original order
    num_keep = int(len(context_ids) * target_ratio)
    keep_idx = token_importance.topk(num_keep).indices.sort().values
    return context_ids[keep_idx]
Part 5: Recurrent and Streaming Approaches
Making recurrent architectures competitive with transformers is another fascinating direction. The goal here is to process infinite streams while maintaining a fixed-size state—a fundamentally different approach from attention-based methods. I'll do a deep dive on these architectures soon, but for now, here are the basics.
RWKV: Linear RNN with Attention-like Expressivity
RWKV (Receptance Weighted Key Value) replaces attention with a linear recurrence that can express similar patterns. The key innovation is that it maintains a running weighted sum of values, where the weights decay exponentially over time.
The recurrence has two components: a decay mechanism that controls how much history to remember, and a weighted aggregation that combines past values. The weight given to the token at position $i$ when computing position $t$ decays as

$$e^{-(t-1-i)\,w}$$

This decay term determines how much weight to give to tokens at different positions. The exponential decay means recent tokens get more weight, but older tokens aren't completely forgotten—they're just weighted less.
The output combines this decay with key-value attention:

$$o_t = \sigma(r_t) \odot \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i}\, v_i \;+\; e^{u + k_t}\, v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} \;+\; e^{u + k_t}}$$

Here, $r_t$ is a "receptance" gate that controls how much of the aggregated information to use, $k_i$ are keys that determine relevance, and $v_i$ are the values being aggregated. The exponential terms create attention-like weights, but the recurrence structure allows efficient computation.
Key insight: This formulation has a remarkable property—it can be computed as an RNN with $O(1)$ time per token during inference (just update the running sums), but during training, it can be parallelized like standard attention. This gives you the best of both worlds: efficient streaming inference and fast parallel training. The model learns to use the decay mechanism to focus on relevant history while maintaining linear complexity.
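Here's a rough sketch of the recurrent (inference-time) form, keeping a running numerator and denominator per channel; it's a simplification that drops the $u$ bonus term for the current token, and the function name and shapes are my own:

```python
import torch

def rwkv_wkv_recurrent(k: torch.Tensor, v: torch.Tensor, w: torch.Tensor, r: torch.Tensor):
    """Simplified recurrent WKV with per-channel exponential decay.
    k, v, r: [seq, d]; w: [d] positive decay rates."""
    seq_len, d = k.shape
    num = torch.zeros(d)          # running weighted sum of values
    den = torch.zeros(d)          # running sum of weights
    outputs = []
    for t in range(seq_len):
        # Decay the old state, then add the current token's contribution
        num = torch.exp(-w) * num + torch.exp(k[t]) * v[t]
        den = torch.exp(-w) * den + torch.exp(k[t])
        wkv = num / (den + 1e-9)
        # Receptance gate decides how much of the aggregate to let through
        outputs.append(torch.sigmoid(r[t]) * wkv)
    return torch.stack(outputs)   # [seq, d] -- O(1) state update per token
```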
Mamba / State Space Models
Maintain a fixed-size hidden state that summarizes all history:

$$h_t = \bar{A} h_{t-1} + \bar{B} x_t, \qquad y_t = C h_t$$

(in Mamba, the SSM parameters are input-dependent, which lets the model select what to store in the state).
Effective context: Unlimited (state compresses all history)
Limitation: Fixed state size means lossy compression. Some information is forgotten.
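To illustrate the fixed-size state, here's a toy diagonal state-space recurrence (illustrative only; Mamba additionally makes $\bar{A}, \bar{B}, C$ input-dependent and uses a hardware-aware parallel scan):

```python
import torch

def ssm_scan(x: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor):
    """Sequential scan of a toy diagonal linear state-space model.
    x: [seq, d_in]; A: [d_state] diagonal transition (|A| < 1 for stability);
    B: [d_state, d_in]; C: [d_out, d_state]."""
    seq_len = x.shape[0]
    h = torch.zeros(A.shape[0])    # fixed-size hidden state, regardless of seq_len
    ys = []
    for t in range(seq_len):
        h = A * h + B @ x[t]       # fold the new input into the state (lossy compression)
        ys.append(C @ h)           # read out from the state
    return torch.stack(ys)         # [seq, d_out]
```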
Transformer-XL: Segment-Level Recurrence
Process context in segments, passing hidden states between segments:
Segment 1: [tokens 0-512] → hidden_1
Segment 2: [tokens 512-1024] → hidden_2 (conditioned on hidden_1)
Segment 3: [tokens 1024-1536] → hidden_3 (conditioned on hidden_2)
class TransformerXL(nn.Module):
def __init__(self, d_model: int, segment_len: int = 512):
super().__init__()
self.segment_len = segment_len
self.layers = nn.ModuleList([...])
def forward_segment(self, x, memory=None):
"""Process one segment with memory from previous segment"""
for layer in self.layers:
if memory is not None:
# Concatenate memory for extended context
k = torch.cat([memory, x], dim=1)
v = torch.cat([memory, x], dim=1)
else:
k, v = x, x
x = layer(x, k, v)
return x
def forward_stream(self, token_stream):
"""Process infinite stream of tokens"""
memory = None
for segment in chunk(token_stream, self.segment_len):
output = self.forward_segment(segment, memory)
memory = output.detach() # Detach to prevent infinite backprop
            yield output
Comparison: What Works When?
| Method | Effective Length | Latency | Use Case |
|---|---|---|---|
| RoPE + YaRN | ~128K | Medium | General long-context |
| LongRoPE | ~2M | Medium | Extreme length extension |
| Sliding Window | Unlimited* | Fast | Streaming, local tasks |
| Attention Sinks | Unlimited* | Very fast | Streaming inference |
| Ring Attention | ~1M+ | High (distributed) | Training on very long docs |
| Infini-Attention | ~1M+ | Medium | Production 1M+ contexts |
| RAG | Unlimited** | Medium | Knowledge-intensive tasks |
| RWKV/Mamba | Unlimited | Very fast | Efficiency-critical |
| Compression | 10-30x ratio | Fast | Prompt optimization |
* Unlimited for local-attention quality; long-range/global quality degrades
** Depends on retrieval quality
Some Final Thoughts
- 128K native context is "solved": YaRN + FlashAttention + GQA make this practical
- 1M context is achievable: Infini-attention + Ring attention make this real (Gemini proves it)
- Streaming has multiple good solutions: From simple (Attention Sinks) to adaptive (SnapKV, H2O)—pick based on your latency/quality tradeoff
- KV cache compression is the key bottleneck: Most streaming methods are really about which tokens to keep
- True "infinite" context requires hybrid approaches: Native context + compressive memory + retrieval
- Quality degrades gracefully: Even with 10M tokens, models struggle with "needle in haystack" tasks—retrieval within context remains hard
References
- Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B., & Liu, Y. (2021). RoFormer: Enhanced Transformer with Rotary Position Embedding. arXiv preprint. arXiv:2104.09864
- Press, O., Smith, N. A., & Lewis, M. (2022). Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation. ICLR. arXiv:2108.12409
- Chen, S., Wong, S., Chen, L., & Tian, Y. (2023). Extending Context Window of Large Language Models via Positional Interpolation. arXiv preprint. arXiv:2306.15595
- Peng, B., Quesnelle, J., Fan, H., & Shippole, E. (2023). YaRN: Efficient Context Window Extension of Large Language Models. arXiv preprint. arXiv:2309.00071
- Liu, H., Yan, W., Zaharia, M., & Abbeel, P. (2023). Ring Attention with Blockwise Transformers for Near-Infinite Context. arXiv preprint. arXiv:2310.01889
- Mohtashami, A., & Jaggi, M. (2023). Landmark Attention: Random-Access Infinite Context Length for Transformers. arXiv preprint. arXiv:2305.16300
- Wu, Y., Rabe, M. N., Hutchins, D., & Szegedy, C. (2022). Memorizing Transformers. ICLR. arXiv:2203.08913
- Mu, J., Li, X., & Goodman, N. D. (2023). Learning to Compress Prompts with Gist Tokens. NeurIPS. arXiv:2304.08467
- Ge, T., Hu, J., Wang, X., Chen, S., & Wei, F. (2024). In-context Autoencoder for Context Compression in a Large Language Model. ICLR. arXiv:2307.06945
- Jiang, H., Wu, Q., Lin, C. Y., Yang, Y., & Qiu, L. (2023). LongLLMLingua: Accelerating and Enhancing LLMs in Long Context Scenarios via Prompt Compression. arXiv preprint. arXiv:2310.06839
- Peng, B., Alcaide, E., Anthony, Q., et al. (2023). RWKV: Reinventing RNNs for the Transformer Era. EMNLP. arXiv:2305.13048
- Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv preprint. arXiv:2312.00752
- Dai, Z., Yang, Z., Yang, Y., Carbonell, J., Le, Q. V., & Salakhutdinov, R. (2019). Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context. ACL. arXiv:1901.02860
- Reid, M., Savinov, N., Teber, D., et al. (2024). Gemini 1.5: Unlocking multimodal understanding across millions of tokens of context. arXiv preprint. arXiv:2403.05530
- Bertsch, A., Alon, U., Neubig, G., & Gormley, M. R. (2024). Unlimiformer: Long-Range Transformers with Unlimited Length Input. NeurIPS. arXiv:2305.01625
- Xiao, G., Tian, Y., Chen, B., Han, S., & Lewis, M. (2023). Efficient Streaming Language Models with Attention Sinks. ICLR. arXiv:2309.17453
- Munkhdalai, T., Faruqui, M., & Gopal, S. (2024). Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention. arXiv preprint. arXiv:2404.07143
- Ding, Y., Zhang, L., Shang, J., Xu, J., et al. (2024). LongRoPE: Extending LLM Context Window Beyond 2 Million Tokens. arXiv preprint. arXiv:2402.13753
- Han, C., Wang, Q., Xiong, W., et al. (2024). LM-Infinite: Simple On-the-Fly Length Generalization for Large Language Models. NAACL. arXiv:2308.16137
- Zhang, Z., Sheng, Y., Zhou, T., Chen, T., et al. (2024). H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models. NeurIPS. arXiv:2306.14048
- Liu, Z., Desai, A., Liao, F., Wang, W., Xie, V., Xu, Z., Kyrillidis, A., & Shrivastava, A. (2023). Scissorhands: Exploiting the Persistence of Importance Hypothesis for LLM KV Cache Compression at Test Time. NeurIPS. arXiv:2305.17118
- Li, Y., He, Y., Sun, Y., Tan, Z., Yan, G., et al. (2024). SnapKV: LLM Knows What You are Looking for Before Generation. arXiv preprint. arXiv:2404.14469
- Cai, Z., Zhang, Y., Gao, B., Liu, Y., et al. (2024). PyramidKV: Dynamic KV Cache Compression based on Pyramidal Information Funneling. arXiv preprint. arXiv:2406.02069
- Ge, S., Zhang, Y., Liu, L., Zhang, M., Han, J., & Gao, J. (2024). Model Tells You What to Discard: Adaptive KV Cache Compression for LLMs. ICLR. arXiv:2310.01801
- Sheng, Y., Zheng, L., Yuan, B., Li, Z., et al. (2023). FlexGen: High-Throughput Generative Inference of Large Language Models with a Single GPU. ICML. arXiv:2303.06865
- Hooper, C., Kim, S., Mohammadzadeh, H., et al. (2024). KVQuant: Towards 10 Million Context Length LLM Inference with KV Cache Quantization. arXiv preprint. arXiv:2401.18079