Attention That You Probably Didn't Know Existed!!
From sparse patterns to linear attention and state space models: exploring the zoo of efficient attention mechanisms that go beyond vanilla transformers.
Table of Contents
- The Quadratic Bottleneck
- Approach 1: Sparse Attention Patterns
- Fixed Patterns (Longformer, BigBird)
- Learned Sparse Patterns (Routing Transformer, Reformer)
- Approach 2: Linear Attention (Kernel Methods)
- The Kernel Trick
- Popular Linear Attention Variants
- Performer: Random Feature Maps
- Approach 3: Low-Rank Approximation
- Linformer
- Nyströmformer
- Approach 4: State Space Models (Mamba, S4)
- The State Space Formulation
- Mamba: Selective State Spaces
- Approach 5: IO-Aware Algorithms (FlashAttention)
- The Memory Hierarchy Insight
- Approach 6: Grouped-Query Attention (GQA)
- The KV Cache Problem
- Multi-Query Attention (MQA)
- Grouped-Query Attention: The Sweet Spot
- The Hybrid Future
- Sliding Window + Global (Mistral, Gemma 2)
- Attention + SSM (Jamba, Zamba)
- Comparison: What Actually Works?
- Final Thoughts: The Quadratic Wall Is Softer Than It Looks
- References
I've been looking at self-attention since 2019, and since then a lot of variants of self-attention have been proposed. In this article, I break down efficient attention mechanisms: linear attention, sparse patterns, low-rank approximations, grouped attention, and several other variants as well. Time to pay attention to self-attention xD.
The Quadratic Bottleneck
Standard self-attention computes:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

The $QK^\top$ matrix has shape $n \times n$. For a 100K token sequence with 128 attention heads across 80 layers, you're looking at:

$$100{,}000^2 \times 128 \times 80 \approx 10^{14} \text{ attention scores per forward pass}$$

Obviously impossible. Even at 4K tokens, this dominates memory usage. The challenge: can we approximate attention without materializing the full $n \times n$ matrix?
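For reference, here's a minimal sketch of the vanilla computation that everything below tries to avoid. It materializes the full $n \times n$ score matrix:

```python
import torch
import torch.nn.functional as F

def vanilla_attention(Q, K, V):
    """Standard softmax attention. Materializes the full [n, n] score matrix."""
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # [batch, n, n]  <- the O(n^2) term
    attn = F.softmax(scores, dim=-1)
    return attn @ V                                # [batch, n, d_v]

# At n = 4096, the score matrix alone is 4096^2 floats per head, per layer.
```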
Approach 1: Sparse Attention Patterns
The simplest idea: don't compute all attention scores. Only attend to a subset of positions.
Fixed Patterns (Longformer, BigBird)
Define a fixed sparsity pattern combining:
- Local attention: Each token attends to neighbors
- Global attention: Special tokens attend to everything
- Random attention: Sparse random connections
```python
import torch

def create_longformer_mask(seq_len: int, window_size: int, global_tokens: int):
    """
    Creates a Longformer-style attention mask.
    Returns: [seq_len, seq_len] boolean mask (True = attend)
    """
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    # Local sliding window
    for i in range(seq_len):
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2 + 1)
        mask[i, start:end] = True
    # Global tokens (first k tokens attend to/from everything)
    mask[:global_tokens, :] = True
    mask[:, :global_tokens] = True
    return mask
```

Complexity: $O(n \cdot w)$ where $w$ is the window size—linear in sequence length!
Attention patterns: Full n² attention, Sliding window, Dilated sliding window, Global+sliding window
The catch: Fixed patterns assume locality matters. For tasks requiring long-range reasoning, you're hoping information propagates through the sparse graph. Sometimes it doesn't.
Learned Sparse Patterns (Routing Transformer, Reformer)
Let the model learn which tokens to attend to:
Reformer uses Locality-Sensitive Hashing (LSH):
Tokens with the same hash are grouped into buckets and only attend within buckets.
```python
import torch
import torch.nn.functional as F

def lsh_attention(Q, K, V, num_hashes: int = 4, num_buckets: int = 32):
    """Simplified LSH attention (conceptual; single hash round shown,
    Reformer runs num_hashes rounds and merges them)."""
    batch, seq_len, d = Q.shape
    # Generate random projection matrix
    R = torch.randn(d, num_buckets // 2, device=Q.device)
    # Hash queries and keys (Reformer ties Q and K; we hash Q for illustration)
    projections = Q @ R  # [batch, seq_len, num_buckets//2]
    hashes = torch.argmax(
        torch.cat([projections, -projections], dim=-1), dim=-1
    )
    # Sort by hash to group similar tokens
    sorted_indices = torch.argsort(hashes, dim=1)
    Q_sorted = torch.gather(Q, 1, sorted_indices.unsqueeze(-1).expand_as(Q))
    K_sorted = torch.gather(K, 1, sorted_indices.unsqueeze(-1).expand_as(K))
    V_sorted = torch.gather(V, 1, sorted_indices.unsqueeze(-1).expand_as(V))
    # Attend within fixed-size chunks (approximate buckets); assumes seq_len % num_buckets == 0
    chunk = seq_len // num_buckets
    Qc = Q_sorted.view(batch, num_buckets, chunk, d)
    Kc = K_sorted.view(batch, num_buckets, chunk, d)
    Vc = V_sorted.view(batch, num_buckets, chunk, d)
    scores = Qc @ Kc.transpose(-2, -1) / d ** 0.5  # [batch, buckets, chunk, chunk]
    out_sorted = (F.softmax(scores, dim=-1) @ Vc).reshape(batch, seq_len, d)
    # Undo the sort to restore the original token order
    output = torch.empty_like(out_sorted)
    output.scatter_(1, sorted_indices.unsqueeze(-1).expand_as(out_sorted), out_sorted)
    return output
```

Complexity: $O(n \log n)$ for LSH bucketing + $O(n \cdot c)$ for within-bucket attention, where $c$ is the bucket (chunk) size.
The catch: Hash collisions are probabilistic. Similar tokens might end up in different buckets. You need multiple hash rounds to reduce this, eating into efficiency gains.
Approach 2: Linear Attention (Kernel Methods)
The most mathematically elegant approach: reformulate attention to avoid the explicit $n \times n$ matrix.
The Kernel Trick
Standard softmax attention for query $i$:

$$\mathrm{out}_i = \frac{\sum_j \exp(q_i^\top k_j)\, v_j}{\sum_j \exp(q_i^\top k_j)}$$

Replace the exponential kernel with a decomposable feature map $\phi$, i.e. $\mathrm{sim}(q_i, k_j) = \phi(q_i)^\top \phi(k_j)$. Now attention becomes:

$$\mathrm{out}_i = \frac{\phi(q_i)^\top \sum_j \phi(k_j)\, v_j^\top}{\phi(q_i)^\top \sum_j \phi(k_j)}$$

The key insight: $\sum_j \phi(k_j)\, v_j^\top$ and $\sum_j \phi(k_j)$ can be computed once and reused for all queries!
```python
import torch
import torch.nn.functional as F

def linear_attention(Q, K, V, feature_map=None):
    """
    Linear attention via kernel feature maps.
    Complexity: O(n * d^2) instead of O(n^2 * d)
    """
    if feature_map is None:
        # ELU-based feature map (from "Transformers are RNNs")
        feature_map = lambda x: F.elu(x) + 1
    Q = feature_map(Q)  # [batch, seq, d]
    K = feature_map(K)  # [batch, seq, d]
    # Compute KV summary: [batch, d, d_v]
    KV = torch.einsum('bnd,bnv->bdv', K, V)
    # Compute K summary for normalization: [batch, d]
    K_sum = K.sum(dim=1)
    # Compute attention output
    numerator = torch.einsum('bnd,bdv->bnv', Q, KV)
    denominator = torch.einsum('bnd,bd->bn', Q, K_sum).unsqueeze(-1)
    return numerator / (denominator + 1e-6)
```

Complexity: $O(n \cdot d^2)$—linear in sequence length!
Popular Linear Attention Variants
| Method | Feature Map | Notes |
|---|---|---|
| Linear Transformer | $\phi(x) = \mathrm{elu}(x) + 1$ | Simple, stable |
| Performer | Random Fourier features | Unbiased softmax approximation |
| cosFormer | $\mathrm{ReLU}(x)$ | cos-based re-weighting for locality |
| RWKV | Exponential decay | RNN-like, state-based |
Performer: Random Feature Maps
Performer provides an unbiased estimator of softmax attention using random Fourier features:
```python
import torch

def random_fourier_features(x, num_features: int = 256):
    """
    Simplified FAVOR+ feature map from Performer.
    """
    d = x.shape[-1]
    # Sample random projections (orthogonal for lower variance)
    W = torch.randn(num_features, d) / (d ** 0.5)
    W = torch.linalg.qr(W.T)[0].T[:num_features]  # Orthogonalize
    # (The real FAVOR+ stacks multiple orthogonal blocks when num_features > d)
    # Compute features
    proj = x @ W.T  # [batch, seq, num_features]
    # Normalize for an unbiased estimate
    norm_factor = (x ** 2).sum(dim=-1, keepdim=True) / 2
    features = torch.exp(proj - norm_factor) / (num_features ** 0.5)
    return features
```

The catch with linear attention: It's not actually approximating softmax well for all distributions. The "softmax" property of sharp, selective attention is lost. For many tasks, quality degrades significantly.
Approach 3: Low-Rank Approximation (Linformer, Nyströmformer)
If the attention matrix has low intrinsic rank, project it to a smaller space.
Linformer
Project keys and values to a fixed dimension $k \ll n$:

$$\mathrm{Attention}(Q, EK, FV) = \mathrm{softmax}\!\left(\frac{Q(EK)^\top}{\sqrt{d}}\right) FV$$

where $E, F \in \mathbb{R}^{k \times n}$ are learned projection matrices.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinformerAttention(nn.Module):
    def __init__(self, d_model: int, seq_len: int, k: int = 256):
        super().__init__()
        self.k = k
        # Projection matrices (fixed to a specific seq_len)
        self.E = nn.Parameter(torch.randn(k, seq_len) / seq_len ** 0.5)
        self.F = nn.Parameter(torch.randn(k, seq_len) / seq_len ** 0.5)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

    def forward(self, x):
        Q = self.W_q(x)  # [batch, n, d]
        K = self.W_k(x)  # [batch, n, d]
        V = self.W_v(x)  # [batch, n, d]
        # Project K and V: [batch, k, d]
        K_proj = torch.einsum('kn,bnd->bkd', self.E, K)
        V_proj = torch.einsum('kn,bnd->bkd', self.F, V)
        # Standard attention on reduced dimensions
        # [batch, n, d] @ [batch, d, k] = [batch, n, k]
        scores = torch.matmul(Q, K_proj.transpose(-2, -1)) / (Q.shape[-1] ** 0.5)
        attn = F.softmax(scores, dim=-1)
        # [batch, n, k] @ [batch, k, d] = [batch, n, d]
        return torch.matmul(attn, V_proj)
```

Complexity: $O(n \cdot k)$ where $k$ is the projection dimension.
The catch: Fixed projection dimension means you're betting the attention matrix is always low-rank. This holds for some tasks but not others. The projection also couples sequence length to architecture—you can't easily handle variable lengths.
Nyströmformer
Uses the Nyström method to approximate the full attention matrix from a subset of "landmark" tokens:

$$\mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right) \approx \mathrm{softmax}\!\left(\frac{Q\tilde{K}^\top}{\sqrt{d}}\right)\left(\mathrm{softmax}\!\left(\frac{\tilde{Q}\tilde{K}^\top}{\sqrt{d}}\right)\right)^{+}\mathrm{softmax}\!\left(\frac{\tilde{Q}K^\top}{\sqrt{d}}\right)$$

where $\tilde{Q}, \tilde{K}$ are landmark queries and keys selected via segment-means pooling.
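A rough sketch of the idea in PyTorch. Landmark selection via segment-means follows the paper, but I use `torch.linalg.pinv` where the paper uses an iterative pseudo-inverse approximation, and I assume the sequence length divides evenly into landmarks:

```python
import torch
import torch.nn.functional as F

def nystrom_attention(Q, K, V, num_landmarks: int = 64):
    """Nyström approximation of softmax attention (conceptual sketch)."""
    batch, n, d = Q.shape
    m = num_landmarks
    # Landmark queries/keys via segment-means pooling over the sequence
    Q_land = Q.reshape(batch, m, n // m, d).mean(dim=2)  # [batch, m, d]
    K_land = K.reshape(batch, m, n // m, d).mean(dim=2)  # [batch, m, d]
    scale = d ** 0.5
    # Three small softmax kernels instead of one n x n kernel
    F1 = F.softmax(Q @ K_land.transpose(-2, -1) / scale, dim=-1)       # [batch, n, m]
    A = F.softmax(Q_land @ K_land.transpose(-2, -1) / scale, dim=-1)   # [batch, m, m]
    F2 = F.softmax(Q_land @ K.transpose(-2, -1) / scale, dim=-1)       # [batch, m, n]
    # Nyström reconstruction: softmax(QK^T/sqrt(d)) ~ F1 @ pinv(A) @ F2
    # Parenthesized right-to-left so nothing of size [n, n] is ever formed
    return F1 @ (torch.linalg.pinv(A) @ (F2 @ V))
```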
Approach 4: State Space Models (Mamba, S4)
The radical approach: abandon attention entirely. Replace it with structured state space models that process sequences in $O(n)$ time.
The State Space Formulation
A continuous-time state space model:

$$h'(t) = A\,h(t) + B\,x(t), \qquad y(t) = C\,h(t)$$

Discretize and unroll:

$$h_t = \bar{A}\,h_{t-1} + \bar{B}\,x_t, \qquad y_t = C\,h_t$$

This is an RNN! But with special structure on $\bar{A}$ that enables (see the sketch after this list):
- Parallel training via convolution
- Linear-time inference via recurrence
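To make the recurrence/convolution duality concrete, here's a toy single-channel SSM with a diagonal $\bar{A}$ and fixed (non-selective) parameters; both functions compute the same output. Names and shapes are illustrative, not from any particular S4 implementation.

```python
import torch

def ssm_recurrent(x, A_bar, B_bar, C):
    """Linear-time recurrence: h_t = A_bar * h_{t-1} + B_bar * x_t,  y_t = C . h_t."""
    h = torch.zeros_like(B_bar)              # [d_state]
    ys = []
    for t in range(x.shape[0]):
        h = A_bar * h + B_bar * x[t]
        ys.append((C * h).sum())
    return torch.stack(ys)

def ssm_convolutional(x, A_bar, B_bar, C):
    """Same map written as a causal convolution with kernel K_j = C . (A_bar^j * B_bar)."""
    seq_len = x.shape[0]
    j = torch.arange(seq_len).unsqueeze(-1)      # [seq_len, 1]
    kernel = (C * (A_bar ** j) * B_bar).sum(-1)  # [seq_len]
    y = torch.zeros(seq_len)
    for t in range(seq_len):                     # explicit causal convolution for clarity
        y[t] = (kernel[:t + 1].flip(0) * x[:t + 1]).sum()
    return y

# x = torch.randn(16); A = torch.rand(4) * 0.9; B = torch.randn(4); C = torch.randn(4)
# torch.allclose(ssm_recurrent(x, A, B, C), ssm_convolutional(x, A, B, C), atol=1e-5)  # True
```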
Mamba: Selective State Spaces
Mamba makes the state space parameters input-dependent:

$$B_t = W_B\, x_t, \qquad C_t = W_C\, x_t, \qquad \Delta_t = \mathrm{softplus}(W_\Delta\, x_t)$$
This selectivity lets the model dynamically decide what to remember and forget—recovering some of the "content-based" addressing that attention provides.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimplifiedMambaBlock(nn.Module):
    """Conceptual Mamba block (the real implementation uses fused CUDA kernels
    and a parallel associative scan; this sequential version is for clarity)."""
    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        # SSM state matrix (log-parameterized so A stays negative after -exp)
        self.A_log = nn.Parameter(torch.randn(d_model, d_state))
        # Input-dependent projections for B, C, delta
        self.x_proj = nn.Linear(d_model, d_state * 2 + 1)
        # Depthwise convolution for local context
        self.conv = nn.Conv1d(d_model, d_model, d_conv, padding=d_conv - 1, groups=d_model)

    def forward(self, x):
        batch, seq_len, d = x.shape
        # Local (causal) convolution
        x_conv = self.conv(x.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x_conv = F.silu(x_conv)
        # Compute input-dependent SSM parameters
        x_dbl = self.x_proj(x_conv)
        B = x_dbl[..., :self.d_state]                  # [batch, seq, d_state]
        C = x_dbl[..., self.d_state:2 * self.d_state]  # [batch, seq, d_state]
        delta = F.softplus(x_dbl[..., -1:])            # [batch, seq, 1]
        # Discretize A: A_bar = exp(delta * A)
        A = -torch.exp(self.A_log)                     # [d_model, d_state]
        A_bar = torch.exp(delta.unsqueeze(-1) * A)     # [batch, seq, d_model, d_state]
        # Recurrent scan (done with a parallel associative scan in practice)
        h = torch.zeros(batch, d, self.d_state, device=x.device)
        outputs = []
        for t in range(seq_len):
            B_bar_x = (delta[:, t, :, None] * B[:, t, None, :]) * x_conv[:, t, :, None]
            h = A_bar[:, t] * h + B_bar_x              # [batch, d_model, d_state]
            y_t = (h * C[:, t, None, :]).sum(-1)       # [batch, d_model]
            outputs.append(y_t)
        return torch.stack(outputs, dim=1)             # [batch, seq, d_model]
```

Complexity: $O(n)$ time and memory!
The catch: SSMs trade the quadratic cost for a fixed-size hidden state. Long-range dependencies must be compressed into this state. For tasks requiring precise retrieval over long contexts (like "find the needle in the haystack"), SSMs struggle compared to attention.
Approach 5: IO-Aware Algorithms (FlashAttention)
Plot twist: maybe the algorithm isn't the problem—the implementation is.
The Memory Hierarchy Insight
Standard attention is memory-bound, not compute-bound. The attention matrix gets written to slow HBM (GPU main memory), then read back for softmax and output computation.
FlashAttention keeps everything in fast SRAM by:
- Tiling the computation into blocks
- Recomputing during backward pass instead of storing
- Fusing operations to minimize memory transfers
```
Standard Attention Memory Access:
Q, K, V (HBM) → Compute QK^T → Store S (HBM) → Load S → Softmax → Store P (HBM) → Load P, V → Output

FlashAttention Memory Access:
Q, K, V (HBM) → [Tiled compute entirely in SRAM] → Output (HBM)
```

Complexity: Still $O(n^2)$ compute, but $O(n)$ memory and 2-4× faster due to reduced memory I/O.
The insight: FlashAttention doesn't reduce computational complexity—it makes the quadratic algorithm practical by respecting the memory hierarchy. For many use cases, this is enough.
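To see why tiling works, here's a slow, pure-PyTorch sketch of the online-softmax recurrence that FlashAttention builds on: each key/value block updates a running max, a running normalizer, and a running output, so the full score matrix is never materialized. This is just the math (single head, no batch), not the fused kernel.

```python
import torch

def tiled_attention(Q, K, V, block_size: int = 128):
    """Online-softmax attention over key/value blocks (single head, no batch)."""
    n, d = Q.shape
    out = torch.zeros(n, V.shape[-1])
    row_max = torch.full((n, 1), float('-inf'))  # running max of scores per query
    row_sum = torch.zeros(n, 1)                  # running softmax normalizer
    for start in range(0, n, block_size):
        K_blk = K[start:start + block_size]
        V_blk = V[start:start + block_size]
        scores = Q @ K_blk.T / d ** 0.5          # [n, block] -- only one block at a time
        blk_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(row_max, blk_max)
        # Rescale previous accumulators to the new max, then fold in this block
        correction = torch.exp(row_max - new_max)
        p = torch.exp(scores - new_max)
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ V_blk
        row_max = new_max
    return out / row_sum
```

FlashAttention applies the same recurrence per query block, keeps the tiles in SRAM, and recomputes intermediates during the backward pass instead of storing them.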
Approach 6: Grouped-Query Attention (GQA)
Here's an approach you might not have heard about that's quietly powering most modern LLMs: what if the bottleneck isn't computing attention, but storing the key-value cache?
The KV Cache Problem
During autoregressive generation, we cache key and value tensors from previous tokens to avoid recomputation. With standard Multi-Head Attention (MHA), each head has its own K and V projections:

$$\text{KV cache size} = 2 \times n_{\text{layers}} \times n_{\text{heads}} \times d_{\text{head}} \times n_{\text{tokens}} \times \text{bytes per element}$$
For a 70B model with 80 layers and 64 heads, generating 8K tokens requires ~40GB just for KV cache. That's often more memory than the model weights themselves!
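A quick back-of-the-envelope helper for the cache-size formula above (the model dimensions below are illustrative assumptions; the exact figure depends on head dimension, precision, and batch size):

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    """2 tensors (K and V) x layers x KV heads x head_dim x tokens x precision."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

# MHA-style cache for a hypothetical 80-layer, 64-head, d_head=128 model at 8K tokens, fp16:
print(kv_cache_bytes(80, 64, 128, 8192) / 1e9)  # ~21 GB per sequence; grows with batch size and precision
```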
Multi-Query Attention (MQA)
The radical simplification: all query heads share a single set of keys and values.
Each query head attends using the same K and V:

$$\mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W^K,\; V W^V)$$

with a single shared $W^K$ and $W^V$ across all heads.

KV cache reduction: $n_{\text{heads}}\times$ smaller (e.g., 64× for 64-head models)!
The catch: Quality degrades. Different heads should learn different attention patterns, but they're forced to share the same keys and values.
Grouped-Query Attention (GQA): The Sweet Spot
GQA is the Goldilocks solution: instead of 1 KV set (MQA) or $H$ KV sets (MHA, one per query head), use $G$ groups where $1 < G < H$.

Each group of query heads shares one set of keys and values:

$$\mathrm{head}_i = \mathrm{Attention}\big(Q_i,\; K_{\lfloor i / (H/G) \rfloor},\; V_{\lfloor i / (H/G) \rfloor}\big)$$
Multi-Head, Grouped-Query, and Multi-Query Attention: Trading KV cache size for quality
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int):
        """
        n_heads: number of query heads
        n_kv_heads: number of key-value heads (groups)
        """
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads  # queries per KV group
        self.head_dim = d_model // n_heads
        self.W_q = nn.Linear(d_model, n_heads * self.head_dim)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch, seq_len, _ = x.shape
        # Project
        Q = self.W_q(x).view(batch, seq_len, self.n_heads, self.head_dim)
        K = self.W_k(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
        V = self.W_v(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
        # Repeat K, V to match query heads
        # [batch, seq, n_kv_heads, head_dim] -> [batch, seq, n_heads, head_dim]
        K = K.repeat_interleave(self.n_rep, dim=2)
        V = V.repeat_interleave(self.n_rep, dim=2)
        # Standard attention
        Q, K, V = Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)
        out = out.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.W_o(out)
```

The Math: Memory Savings
| Attention Type | KV Heads | KV Cache Size | Quality |
|---|---|---|---|
| Multi-Head (MHA) | $H$ (one per query head) | 1× (baseline) | 100% (baseline) |
| Grouped-Query (GQA) | $G$, with $1 < G < H$ | $H/G\times$ smaller | ~99% |
| Multi-Query (MQA) | 1 | $H\times$ smaller | ~95-97% |
Llama 2 70B uses GQA with 8 KV heads and 64 query heads—an 8× reduction in KV cache with minimal quality loss.
Why this matters: GQA doesn't reduce computational complexity (still $O(n^2)$), but it dramatically reduces memory bandwidth during generation. When you're generating tokens one at a time, memory bandwidth for loading the KV cache is often the bottleneck. GQA makes inference 2-4× faster in practice.
Used by: Llama 2/3, Mistral, Gemma, most modern open-source LLMs.
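A nice property of the implementation above is that MHA and MQA fall out as the two extremes of `n_kv_heads`. A quick usage sketch (the shapes and sizes here are illustrative):

```python
import torch

x = torch.randn(2, 1024, 4096)  # [batch, seq, d_model]

mha = GroupedQueryAttention(d_model=4096, n_heads=32, n_kv_heads=32)  # every head has its own KV
gqa = GroupedQueryAttention(d_model=4096, n_heads=32, n_kv_heads=8)   # 4 query heads per KV head
mqa = GroupedQueryAttention(d_model=4096, n_heads=32, n_kv_heads=1)   # all heads share one KV

for name, layer in [("MHA", mha), ("GQA", gqa), ("MQA", mqa)]:
    kv_params = layer.W_k.weight.numel() + layer.W_v.weight.numel()
    print(name, tuple(layer(x).shape), f"KV projection params: {kv_params:,}")
```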
The Hybrid Future
No single approach dominates. The trend is hybrid architectures:
Sliding Window + Global (Mistral, Gemma 2)
```
Layer 1-4:  Sliding window attention (local)
Layer 5:    Global attention (one global layer after every 4 local layers)
Layer 6-9:  Sliding window attention (local)
Layer 10:   Global attention
...
```

This gets $O(n \cdot w)$ for most layers while preserving some global context mixing.
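For decoder-style LLMs, the local layers use a causal sliding-window mask. A minimal sketch (the window size below is illustrative; Mistral 7B uses a 4096-token window):

```python
import torch

def sliding_window_causal_mask(seq_len: int, window_size: int) -> torch.Tensor:
    """True where attention is allowed: causal, and at most window_size tokens back."""
    i = torch.arange(seq_len).unsqueeze(1)  # query positions
    j = torch.arange(seq_len).unsqueeze(0)  # key positions
    return (j <= i) & (j > i - window_size)

mask = sliding_window_causal_mask(seq_len=8, window_size=3)
# Row t attends to tokens t-2, t-1, t; stacking layers widens the effective receptive field.
```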
Attention + SSM (Jamba, Zamba)
Interleave attention layers with Mamba layers:
- Mamba layers: Efficient long-range propagation
- Attention layers: Precise retrieval when needed
```python
import torch.nn as nn

class HybridLayer(nn.Module):
    """Attention or Mamba mixer + FFN.
    MultiHeadAttention, MambaBlock, and FeedForward are assumed to be defined elsewhere."""
    def __init__(self, d_model: int, use_attention: bool):
        super().__init__()
        if use_attention:
            self.mixer = MultiHeadAttention(d_model, num_heads=8)
        else:
            self.mixer = MambaBlock(d_model)
        # Separate pre-norms for the mixer and FFN sub-blocks
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model)

    def forward(self, x):
        x = x + self.mixer(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

# Example: 1 attention layer per 4 Mamba layers
layers = []
for i in range(32):
    use_attention = (i % 4 == 0)
    layers.append(HybridLayer(d_model=2048, use_attention=use_attention))
```

Comparison: What Actually Works?
| Method | Complexity | Quality vs Softmax | Training Stable | Long-Range |
|---|---|---|---|---|
| Full Attention | $O(n^2)$ | ✓ Exact | ✓ Yes | ✓ Excellent |
| FlashAttention | $O(n^2)$ compute, $O(n)$ memory | ✓ Exact | ✓ Yes | ✓ Excellent |
| GQA | $O(n^2)$ | ~99% | ✓ Yes | ✓ Excellent |
| Sparse (Longformer) | $O(n \cdot w)$ | ~95% | ✓ Yes | ~ Depends on pattern |
| Linear (Performer) | $O(n)$ | ~85-90% | ~ Sometimes | ✗ Degraded |
| Low-rank (Linformer) | $O(n \cdot k)$ | ~90% | ✓ Yes | ~ Limited |
| SSM (Mamba) | $O(n)$ | ~90-95% | ✓ Yes | ~ Good, not perfect |
| Hybrid | Mostly $O(n \cdot w)$ | ~97-99% | ✓ Yes | ✓ Very good |
The uncomfortable truth: FlashAttention + sliding window hybrids currently win for most practical applications. Pure linear attention and SSMs trade too much quality for theoretical efficiency gains.
Final Thoughts: The Quadratic Wall Is Softer Than It Looks
- FlashAttention changed the game: $O(n^2)$ compute is fine if memory is $O(n)$
- Hybrid is the answer: Mix efficient local attention with sparse global attention
- SSMs are complements, not replacements: Great for efficiency, but attention still wins for retrieval
- Hardware will catch up: Future GPUs with more SRAM might make the whole debate moot
The quest for efficient attention isn't over, but the practical context ceiling has moved from 4K to 1M+ tokens. For most applications, that's enough.
Next up: How Mamba's selective state spaces actually work, and why the "selectivity" mechanism is the key innovation.
References
- Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention Is All You Need. NeurIPS. arXiv:1706.03762
- Beltagy, I., Peters, M. E., & Cohan, A. (2020). Longformer: The Long-Document Transformer. arXiv preprint. arXiv:2004.05150
- Zaheer, M., Guruganesh, G., Dubey, A., Ainslie, J., Alberti, C., Ontanon, S., Pham, P., Ravula, A., Wang, Q., Yang, L., & Ahmed, A. (2020). Big Bird: Transformers for Longer Sequences. NeurIPS. arXiv:2007.14062
- Kitaev, N., Kaiser, Ł., & Levskaya, A. (2020). Reformer: The Efficient Transformer. ICLR. arXiv:2001.04451
- Katharopoulos, A., Vyas, A., Pappas, N., & Fleuret, F. (2020). Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. ICML. arXiv:2006.16236
- Choromanski, K., Likhosherstov, V., Dohan, D., Song, X., Gane, A., Sarlos, T., Hawkins, P., Davis, J., Mohiuddin, A., Kaiser, L., Belanger, D., Colwell, L., & Weller, A. (2021). Rethinking Attention with Performers. ICLR. arXiv:2009.14794
- Qin, Z., Sun, W., Deng, H., Li, D., Wei, Y., Lv, B., Yan, J., Kong, L., & Zhong, Y. (2022). cosFormer: Rethinking Softmax in Attention. ICLR. arXiv:2202.08791
- Wang, S., Li, B., Khabsa, M., Fang, H., & Ma, H. (2020). Linformer: Self-Attention with Linear Complexity. arXiv preprint. arXiv:2006.04768
- Xiong, Y., Zeng, Z., Chakraborty, R., Tan, M., Fung, G., Li, Y., & Singh, V. (2021). Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention. AAAI. arXiv:2102.03902
- Gu, A., Goel, K., & Ré, C. (2022). Efficiently Modeling Long Sequences with Structured State Spaces. ICLR. arXiv:2111.00396
- Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv preprint. arXiv:2312.00752
- Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS. arXiv:2205.14135
- Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv preprint. arXiv:2307.08691
- Peng, B., Alcaide, E., Anthony, Q., Albalak, A., Arcadinho, S., Biderman, S., Cao, H., Cheng, X., Chung, M., et al. (2023). RWKV: Reinventing RNNs for the Transformer Era. EMNLP. arXiv:2305.13048
- Lieber, O., Lenz, B., Ratner, E., et al. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. arXiv preprint. arXiv:2403.19887
- Jiang, A. Q., Sablayrolles, A., Roux, A., et al. (2024). Mixtral of Experts. arXiv preprint. arXiv:2401.04088
- Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F., & Sanghai, S. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP. arXiv:2305.13245
- Shazeer, N. (2019). Fast Transformer Decoding: One Write-Head is All You Need. arXiv preprint. arXiv:1911.02150