
Attention That You Probably Didn't Know Existed!!

From sparse patterns to linear attention and state space models, exploring the zoo of efficient attention mechanisms that go beyond vanilla transformers.

Transformers · Attention · Efficient AI · Long Context

I've been looking at self-attention since 2019, and since then a lot of variants of self-attention have been proposed. In this article, I break down efficient attention mechanisms: linear attention, sparse patterns, low-rank approximations, grouped-query attention, and a few others besides. Time to pay attention to self-attention xD.

The Quadratic Bottleneck

Standard self-attention computes:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

The matrix QK^T has shape [n × n]. For a 100K-token sequence with 128 attention heads across 80 layers, you're looking at:

\text{Memory} = 100000^2 \times 128 \times 80 \times 2 \text{ bytes} \approx 200 \text{ TB}
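
As a quick sanity check on that figure, here's a back-of-the-envelope script (purely illustrative; fp16 scores, i.e. 2 bytes each, are assumed):

memory_check.py
seq_len, n_heads, n_layers, bytes_per_score = 100_000, 128, 80, 2

# One n x n score matrix per head per layer
attn_matrix_bytes = seq_len ** 2 * n_heads * n_layers * bytes_per_score
print(f"{attn_matrix_bytes / 1e12:.0f} TB")  # ~205 TB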

Obviously impossible. Even at 4K tokens, this dominates memory usage. The challenge: can we approximate attention without materializing the full n × n matrix?

Approach 1: Sparse Attention Patterns

The simplest idea: don't compute all n² attention scores. Only attend to a subset of positions.

Fixed Patterns (Longformer, BigBird)

Define a fixed sparsity pattern combining:

  • Local attention: Each token attends to w neighbors
  • Global attention: Special tokens attend to everything
  • Random attention: Sparse random connections
sparse_attention_pattern.py
import torch

def create_longformer_mask(seq_len: int, window_size: int, global_tokens: int):
    """
    Creates Longformer-style attention mask
    Returns: [seq_len, seq_len] boolean mask
    """
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    
    # Local sliding window
    for i in range(seq_len):
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2 + 1)
        mask[i, start:end] = True
    
    # Global tokens (first k tokens attend to/from everything)
    mask[:global_tokens, :] = True
    mask[:, :global_tokens] = True
    
    return mask
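
To see how the mask plugs in, here's a small usage sketch (note that this dense version still materializes the full n × n score matrix; an efficient implementation computes only the allowed entries):

apply_sparse_mask.py
import torch

seq_len, d = 512, 64
Q = K = V = torch.randn(1, seq_len, d)

mask = create_longformer_mask(seq_len, window_size=64, global_tokens=4)
scores = Q @ K.transpose(-2, -1) / d ** 0.5
scores = scores.masked_fill(~mask, float('-inf'))   # disallowed positions get -inf
out = torch.softmax(scores, dim=-1) @ V             # [1, 512, 64]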

Complexity: O(n·w), where w is the window size. That's linear in sequence length!

Attention patterns: full n² attention, sliding window, dilated sliding window, global + sliding window

The catch: Fixed patterns assume locality matters. For tasks requiring long-range reasoning, you're hoping information propagates through the sparse graph. Sometimes it doesn't.

Learned Sparse Patterns (Routing Transformer, Reformer)

Let the model learn which tokens to attend to:

Reformer uses Locality-Sensitive Hashing (LSH):

h(x) = \text{argmax}([xR; -xR])

Tokens with the same hash are grouped into buckets and only attend within buckets.

lsh_attention.py
import torch

def lsh_attention(Q, K, V, num_buckets: int = 32):
    """Simplified LSH attention (conceptual, single hash round, shared-QK hashing)"""
    batch, seq_len, d = Q.shape
    
    # Random projection matrix for the hash h(x) = argmax([xR; -xR])
    R = torch.randn(d, num_buckets // 2, device=Q.device)
    
    # Hash the queries (Reformer ties Q and K, so one hash covers both)
    projections = Q @ R  # [batch, seq_len, num_buckets//2]
    hashes = torch.argmax(
        torch.cat([projections, -projections], dim=-1), dim=-1
    )
    
    # Sort by hash so tokens in the same bucket become adjacent
    sorted_indices = torch.argsort(hashes, dim=1)
    Q_sorted = torch.gather(Q, 1, sorted_indices.unsqueeze(-1).expand_as(Q))
    K_sorted = torch.gather(K, 1, sorted_indices.unsqueeze(-1).expand_as(K))
    V_sorted = torch.gather(V, 1, sorted_indices.unsqueeze(-1).expand_as(V))
    
    # Attend within fixed-size chunks (approximate buckets);
    # assumes seq_len is divisible by num_buckets
    chunk = seq_len // num_buckets
    Q_c = Q_sorted.view(batch, num_buckets, chunk, d)
    K_c = K_sorted.view(batch, num_buckets, chunk, d)
    V_c = V_sorted.view(batch, num_buckets, chunk, d)
    
    scores = Q_c @ K_c.transpose(-2, -1) / d ** 0.5   # [batch, buckets, chunk, chunk]
    out_sorted = (torch.softmax(scores, dim=-1) @ V_c).view(batch, seq_len, d)
    
    # Undo the sort so outputs line up with the original token order
    output = torch.empty_like(out_sorted)
    output.scatter_(1, sorted_indices.unsqueeze(-1).expand_as(out_sorted), out_sorted)
    return output

Complexity: O(n log n) for the LSH bucketing plus O(n·b) for within-bucket attention, where b is the bucket size.

The catch: Hash collisions are probabilistic. Similar tokens might end up in different buckets. You need multiple hash rounds to reduce this, eating into efficiency gains.

Approach 2: Linear Attention (Kernel Methods)

The most mathematically elegant approach: reformulate attention to avoid the explicit n × n matrix.

The Kernel Trick

Standard softmax attention:

\text{Attention}_i = \frac{\sum_j \exp(q_i^T k_j) v_j}{\sum_j \exp(q_i^T k_j)}

Replace the exponential kernel with a decomposable feature map φ:

\exp(q^T k) \approx \phi(q)^T \phi(k)

Now attention becomes:

\text{Attention}_i = \frac{\phi(q_i)^T \sum_j \phi(k_j) v_j^T}{\phi(q_i)^T \sum_j \phi(k_j)}

The key insight: the summaries Σ_j φ(k_j) v_j^T and Σ_j φ(k_j) can be computed once and reused for all queries!

linear_attention.py
import torch
import torch.nn.functional as F

def linear_attention(Q, K, V, feature_map=None):
    """
    Linear attention via kernel feature maps
    Complexity: O(n * d^2) instead of O(n^2 * d)
    """
    if feature_map is None:
        # ELU-based feature map (from "Transformers are RNNs")
        feature_map = lambda x: F.elu(x) + 1
    
    Q = feature_map(Q)  # [batch, seq, d]
    K = feature_map(K)  # [batch, seq, d]
    
    # Compute KV summary: [batch, d, d_v]
    KV = torch.einsum('bnd,bnv->bdv', K, V)
    
    # Compute K summary for normalization: [batch, d]
    K_sum = K.sum(dim=1)
    
    # Compute attention output
    numerator = torch.einsum('bnd,bdv->bnv', Q, KV)
    denominator = torch.einsum('bnd,bd->bn', Q, K_sum).unsqueeze(-1)
    
    return numerator / (denominator + 1e-6)

Complexity: O(n·d²), linear in sequence length!

Popular Linear Attention Variants

| Method | Feature Map φ | Notes |
| --- | --- | --- |
| Linear Transformer | elu(x) + 1 | Simple, stable |
| Performer | Random Fourier features | Unbiased softmax approximation |
| cosFormer | ReLU(x) · cos(πi / 2n) | Re-weighting for locality |
| RWKV | Exponential decay | RNN-like, state-based |

Performer: Random Feature Maps

Performer provides an unbiased estimator of softmax attention using random Fourier features:

\exp(q^T k) = \mathbb{E}_{\omega \sim \mathcal{N}}\left[\exp\left(\omega^T q - \frac{\|q\|^2}{2}\right) \exp\left(\omega^T k - \frac{\|k\|^2}{2}\right)\right]
performer_features.py
import torch

def random_fourier_features(x, num_features: int = 256):
    """
    FAVOR+-style positive random features (conceptual sketch).
    NOTE: Q and K must be mapped with the *same* random matrix W for the
    softmax estimate to hold; in practice W is sampled once and reused.
    """
    d = x.shape[-1]
    
    # Sample random projections; orthogonalize rows for lower variance.
    # (As written this assumes num_features <= d; FAVOR+ stacks multiple
    # orthogonal blocks when more features are needed.)
    W = torch.randn(num_features, d, device=x.device)
    W = torch.linalg.qr(W.T)[0].T[:num_features]
    W = W * d ** 0.5  # rescale rows to norm sqrt(d), matching Gaussian draws
    
    # Compute features
    proj = x @ W.T  # [batch, seq, num_features]
    
    # exp(w^T x - ||x||^2 / 2) gives positive features for the softmax kernel
    norm_factor = (x ** 2).sum(dim=-1, keepdim=True) / 2
    
    features = torch.exp(proj - norm_factor) / (num_features ** 0.5)
    return features

The catch with linear attention: it doesn't actually approximate softmax well for all input distributions. The sharp, selective behavior of softmax attention is lost, and for many tasks quality degrades significantly.

Approach 3: Low-Rank Approximation (Linformer, Nyströmformer)

If the attention matrix has low intrinsic rank, project it to a smaller space.

Linformer

Project keys and values to a fixed dimension k ≪ n:

\tilde{K} = E \cdot K, \quad \tilde{V} = F \cdot V

where E and F are learned projection matrices of shape [k × n].

linformer.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinformerAttention(nn.Module):
    def __init__(self, d_model: int, seq_len: int, k: int = 256):
        super().__init__()
        self.k = k
        
        # Projection matrices
        self.E = nn.Parameter(torch.randn(k, seq_len) / seq_len ** 0.5)
        self.F = nn.Parameter(torch.randn(k, seq_len) / seq_len ** 0.5)
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        Q = self.W_q(x)  # [batch, n, d]
        K = self.W_k(x)  # [batch, n, d]
        V = self.W_v(x)  # [batch, n, d]
        
        # Project K and V: [batch, k, d]
        K_proj = torch.einsum('kn,bnd->bkd', self.E, K)
        V_proj = torch.einsum('kn,bnd->bkd', self.F, V)
        
        # Standard attention on reduced dimensions
        # [batch, n, d] @ [batch, d, k] = [batch, n, k]
        scores = torch.matmul(Q, K_proj.transpose(-2, -1)) / (Q.shape[-1] ** 0.5)
        attn = F.softmax(scores, dim=-1)
        
        # [batch, n, k] @ [batch, k, d] = [batch, n, d]
        return torch.matmul(attn, V_proj)

Complexity: O(n·k), where k is the projection dimension.

The catch: Fixed projection dimension means you're betting the attention matrix is always low-rank. This holds for some tasks but not others. The projection also couples sequence length to architecture—you can't easily handle variable lengths.

Nyströmformer

Uses the Nyström method to approximate the full attention matrix from a subset of "landmark" tokens:

\tilde{A} = \text{softmax}(Q \tilde{K}^T) \cdot \left(\text{softmax}(\tilde{Q} \tilde{K}^T)\right)^{+} \cdot \text{softmax}(\tilde{Q} K^T)

where Q̃ and K̃ are landmark queries and keys selected via pooling, and (·)^+ denotes a pseudoinverse of the small landmark-by-landmark matrix.
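
A minimal sketch of the idea (assuming the sequence length divides evenly by the number of landmarks and that landmarks come from simple segment-mean pooling; the real Nyströmformer approximates the pseudoinverse iteratively instead of calling pinv):

nystrom_attention.py
import torch
import torch.nn.functional as F

def nystrom_attention(Q, K, V, num_landmarks: int = 64):
    """Nystrom approximation of softmax attention (conceptual sketch)."""
    batch, n, d = Q.shape
    scale = d ** -0.5

    # Landmarks: mean-pool contiguous segments of the sequence
    Q_l = Q.reshape(batch, num_landmarks, n // num_landmarks, d).mean(dim=2)
    K_l = K.reshape(batch, num_landmarks, n // num_landmarks, d).mean(dim=2)

    # Three small softmax kernels instead of one n x n matrix
    kernel_1 = F.softmax(Q @ K_l.transpose(-2, -1) * scale, dim=-1)    # [batch, n, m]
    kernel_2 = F.softmax(Q_l @ K_l.transpose(-2, -1) * scale, dim=-1)  # [batch, m, m]
    kernel_3 = F.softmax(Q_l @ K.transpose(-2, -1) * scale, dim=-1)    # [batch, m, n]

    # A_tilde @ V ~= kernel_1 @ pinv(kernel_2) @ (kernel_3 @ V)
    return kernel_1 @ torch.linalg.pinv(kernel_2) @ (kernel_3 @ V)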

Approach 4: State Space Models (Mamba, S4)

The radical approach: abandon attention entirely. Replace it with structured state space models that process sequences in O(n) time.

The State Space Formulation

A continuous-time state space model:

\frac{dx}{dt} = Ax(t) + Bu(t), \quad y(t) = Cx(t) + Du(t)

Discretize and unroll:

x_k = \bar{A} x_{k-1} + \bar{B} u_k, \quad y_k = C x_k

This is an RNN! But with special structure on A that enables (see the small sketch after this list):

  1. Parallel training via convolution
  2. Linear-time inference via recurrence
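
To make point 1 concrete, here's a tiny scalar-state example (purely illustrative, not from any SSM library): unrolling the recurrence gives y_k = Σ_j C · A_bar^j · B_bar · u_{k-j}, so the whole output is a causal convolution of the input with the kernel (C·B_bar, C·A_bar·B_bar, C·A_bar²·B_bar, ...), which is what allows parallel training.

ssm_scan_vs_conv.py
import torch

def ssm_recurrent(A_bar, B_bar, C, u):
    """Sequential scan: x_k = A_bar * x_{k-1} + B_bar * u_k,  y_k = C * x_k (scalar state)."""
    x, ys = 0.0, []
    for u_k in u:
        x = A_bar * x + B_bar * u_k
        ys.append(C * x)
    return torch.stack(ys)

def ssm_convolutional(A_bar, B_bar, C, u):
    """Same SSM computed as a causal convolution with kernel K_j = C * A_bar**j * B_bar."""
    L = len(u)
    kernel = C * B_bar * (A_bar ** torch.arange(L, dtype=u.dtype))
    # y_k = sum_j kernel[j] * u[k - j]
    return torch.stack([(torch.flip(kernel[:k + 1], dims=[0]) * u[:k + 1]).sum() for k in range(L)])

u = torch.randn(16)
print(torch.allclose(ssm_recurrent(0.9, 0.5, 1.2, u),
                     ssm_convolutional(0.9, 0.5, 1.2, u), atol=1e-5))  # True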

Mamba: Selective State Spaces

Mamba makes the state space parameters input-dependent:

B_t = \text{Linear}(x_t), \quad C_t = \text{Linear}(x_t), \quad \Delta_t = \text{softplus}(\text{Linear}(x_t))

This selectivity lets the model dynamically decide what to remember and forget—recovering some of the "content-based" addressing that attention provides.

mamba_block.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimplifiedMambaBlock(nn.Module):
    """Conceptual Mamba block (the actual implementation uses fused CUDA kernels)"""
    
    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        
        # SSM state matrix (log-parameterized so A stays negative)
        self.A_log = nn.Parameter(torch.randn(d_model, d_state))
        
        # Input-dependent projections
        self.x_proj = nn.Linear(d_model, d_state * 2 + 1)  # B, C, delta
        
        # Depthwise convolution for local context
        self.conv = nn.Conv1d(d_model, d_model, d_conv, padding=d_conv-1, groups=d_model)
        
    def forward(self, x):
        batch, seq_len, d = x.shape
        
        # Local convolution (causal: trim the right-side padding)
        x_conv = self.conv(x.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x_conv = F.silu(x_conv)
        
        # Compute input-dependent SSM parameters
        x_dbl = self.x_proj(x_conv)
        B = x_dbl[..., :self.d_state]                    # [batch, seq, d_state]
        C = x_dbl[..., self.d_state:2*self.d_state]      # [batch, seq, d_state]
        delta = F.softplus(x_dbl[..., -1:])              # [batch, seq, 1]
        
        # Discretize A: A_bar[b, t] = exp(delta[b, t] * A)
        A = -torch.exp(self.A_log)                       # [d_model, d_state]
        A_bar = torch.exp(delta.unsqueeze(-1) * A)       # [batch, seq, d_model, d_state]
        
        # Recurrent scan (a parallel associative scan in the real implementation)
        h = torch.zeros(batch, d, self.d_state, device=x.device)
        outputs = []
        for t in range(seq_len):
            u_t = x_conv[:, t].unsqueeze(-1)             # [batch, d_model, 1]
            B_t = B[:, t].unsqueeze(1)                   # [batch, 1, d_state]
            dt = delta[:, t].unsqueeze(-1)               # [batch, 1, 1]
            h = A_bar[:, t] * h + dt * B_t * u_t         # [batch, d_model, d_state]
            y_t = (h * C[:, t].unsqueeze(1)).sum(-1)     # [batch, d_model]
            outputs.append(y_t)
        
        return torch.stack(outputs, dim=1)               # [batch, seq, d_model]

Complexity: O(n) time and memory!

The catch: SSMs trade the quadratic cost for a fixed-size hidden state. Long-range dependencies must be compressed into this state. For tasks requiring precise retrieval over long contexts (like "find the needle in the haystack"), SSMs struggle compared to attention.

Approach 5: IO-Aware Algorithms (FlashAttention)

Plot twist: maybe the algorithm isn't the problem—the implementation is.

The Memory Hierarchy Insight

Standard attention is memory-bound, not compute-bound. The O(n²) attention matrix gets written to slow HBM (GPU main memory), then read back for the softmax and the output computation.

FlashAttention keeps everything in fast SRAM by:

  1. Tiling the computation into blocks
  2. Recomputing during backward pass instead of storing
  3. Fusing operations to minimize memory transfers
Standard Attention Memory Access:
Q, K, V (HBM) → Compute QK^T → Store S (HBM) → Load S → Softmax → Store P (HBM) → Load P, V → Output
 
FlashAttention Memory Access:  
Q, K, V (HBM) → [Tiled compute entirely in SRAM] → Output (HBM)

Complexity: Still O(n²) compute, but O(n) memory and 2-4× faster due to reduced memory I/O.

The insight: FlashAttention doesn't reduce computational complexity—it makes the quadratic algorithm practical by respecting the memory hierarchy. For many use cases, this is enough.
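
Here's a minimal sketch of the tiling + online-softmax trick in plain PyTorch (illustrative only; the real FlashAttention is a fused CUDA kernel that keeps each tile in SRAM and never writes the full score matrix to HBM):

tiled_attention.py
import torch

def tiled_attention(Q, K, V, block_size: int = 128):
    """Exact attention computed block-by-block over keys/values with a running softmax."""
    batch, n, d = Q.shape
    scale = d ** -0.5

    out = torch.zeros_like(Q)
    row_max = torch.full((batch, n, 1), float('-inf'), device=Q.device)
    row_sum = torch.zeros(batch, n, 1, device=Q.device)

    for start in range(0, n, block_size):
        K_blk = K[:, start:start + block_size]
        V_blk = V[:, start:start + block_size]

        scores = Q @ K_blk.transpose(-2, -1) * scale          # [batch, n, block]

        # Online softmax: update the running max, rescale previous accumulators
        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)
        p = torch.exp(scores - new_max)

        out = out * correction + p @ V_blk
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        row_max = new_max

    return out / row_sum

# Matches standard softmax attention up to floating-point error
Q = K = V = torch.randn(1, 512, 64)
ref = torch.softmax(Q @ K.transpose(-2, -1) * 64 ** -0.5, dim=-1) @ V
print(torch.allclose(tiled_attention(Q, K, V), ref, atol=1e-5))  # True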

Approach 6: Grouped-Query Attention (GQA)

Here's an approach you might not have heard about that's quietly powering most modern LLMs: what if the bottleneck isn't computing attention, but storing the key-value cache?

The KV Cache Problem

During autoregressive generation, we cache key and value tensors from previous tokens to avoid recomputation. With standard Multi-Head Attention (MHA), each head has its own K and V projections:

\text{KV Cache Size} = 2 \times n_{layers} \times n_{heads} \times d_{head} \times \text{seq\_len} \times \text{bytes per element}

For a 70B-class model with 80 layers, 64 heads, and head dimension 128, generating 8K tokens takes roughly 20 GB in fp16 just for the KV cache under standard MHA, and the cache grows linearly with both sequence length and batch size. At longer contexts and larger batches it can rival the model weights themselves.
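
To put numbers on this, a small helper (a sketch; the kv_cache_bytes function and the 80-layer / 64-head / 128-dim configuration are illustrative assumptions, with fp16 storage):

kv_cache_math.py
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    # Factor of 2 covers both keys and values
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

# Standard MHA: each of the 64 query heads stores its own K/V
mha = kv_cache_bytes(n_layers=80, n_kv_heads=64, head_dim=128, seq_len=8192)
# GQA with 8 KV heads (Llama 2 70B-style)
gqa = kv_cache_bytes(n_layers=80, n_kv_heads=8, head_dim=128, seq_len=8192)

print(f"MHA: {mha / 1e9:.1f} GB, GQA: {gqa / 1e9:.1f} GB")  # MHA: 21.5 GB, GQA: 2.7 GB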

Multi-Query Attention (MQA)

The radical simplification: all query heads share a single set of keys and values.

\text{MQA: } Q \in \mathbb{R}^{n \times h \times d}, \quad K, V \in \mathbb{R}^{n \times 1 \times d}

Each query head attends using the same K and V:

\text{Attention}_i = \text{softmax}\left(\frac{Q_i K^T}{\sqrt{d_k}}\right)V

KV cache reduction: h× smaller (e.g., 64× for 64-head models)!

The catch: Quality degrades. Different heads should learn different attention patterns, but they're forced to share the same keys and values.

Grouped-Query Attention (GQA): The Sweet Spot

GQA is the Goldilocks solution: instead of 1 KV set (MQA) or h KV sets (MHA), use g groups where 1 < g < h.

\text{GQA: } Q \in \mathbb{R}^{n \times h \times d}, \quad K, V \in \mathbb{R}^{n \times g \times d}

Each group of h/g query heads shares one set of keys and values:

\text{Attention}_i = \text{softmax}\left(\frac{Q_i K_{\lfloor i \cdot g / h \rfloor}^T}{\sqrt{d_k}}\right)V_{\lfloor i \cdot g / h \rfloor}

Multi-Head, Grouped-Query, and Multi-Query Attention: trading KV cache size for quality

grouped_query_attention.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int):
        """
        n_heads: number of query heads
        n_kv_heads: number of key-value heads (groups)
        """
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads  # queries per KV group
        self.head_dim = d_model // n_heads
        
        self.W_q = nn.Linear(d_model, n_heads * self.head_dim)
        self.W_k = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.W_v = nn.Linear(d_model, n_kv_heads * self.head_dim)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        batch, seq_len, _ = x.shape
        
        # Project
        Q = self.W_q(x).view(batch, seq_len, self.n_heads, self.head_dim)
        K = self.W_k(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
        V = self.W_v(x).view(batch, seq_len, self.n_kv_heads, self.head_dim)
        
        # Repeat K, V to match query heads
        # [batch, seq, n_kv_heads, head_dim] -> [batch, seq, n_heads, head_dim]
        K = K.repeat_interleave(self.n_rep, dim=2)
        V = V.repeat_interleave(self.n_rep, dim=2)
        
        # Standard attention
        Q, K, V = Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)
        
        out = out.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.W_o(out)
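
MHA, GQA, and MQA all fall out of the same module by varying n_kv_heads (a quick usage sketch of the class above; the sizes are arbitrary):

gqa_usage.py
import torch

x = torch.randn(2, 256, 4096)  # [batch, seq, d_model]

mha = GroupedQueryAttention(d_model=4096, n_heads=32, n_kv_heads=32)  # standard multi-head
gqa = GroupedQueryAttention(d_model=4096, n_heads=32, n_kv_heads=8)   # 4 query heads per KV head
mqa = GroupedQueryAttention(d_model=4096, n_heads=32, n_kv_heads=1)   # one shared KV head

print(mha(x).shape, gqa(x).shape, mqa(x).shape)  # all torch.Size([2, 256, 4096])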

The Math: Memory Savings

| Attention Type | KV Heads | KV Cache Size (elements) | Quality |
| --- | --- | --- | --- |
| Multi-Head (MHA) | h | 2 · L · h · d · n | 100% (baseline) |
| Grouped-Query (GQA) | g | 2 · L · g · d · n | ~99% |
| Multi-Query (MQA) | 1 | 2 · L · 1 · d · n | ~95-97% |

(L = layers, d = head dimension, n = sequence length)

Llama 2 70B uses GQA with 8 KV heads and 64 query heads—an 8× reduction in KV cache with minimal quality loss.

Why this matters: GQA doesn't reduce computational complexity (still O(n²)), but it dramatically reduces memory bandwidth during generation. When you're generating tokens one at a time, memory bandwidth for loading the KV cache is often the bottleneck. GQA makes inference 2-4× faster in practice.

Used by: Llama 2/3, Mistral, Gemma, most modern open-source LLMs.

The Hybrid Future

No single approach dominates. The trend is hybrid architectures:

Sliding Window + Global (Mistral, Gemma 2)

Layers 1-3:  Sliding window attention (local)
Layer 4:     Global attention (every 4th layer)
Layers 5-7:  Sliding window attention (local)
Layer 8:     Global attention
...

This gets O(n) cost for most layers while preserving some global context mixing.
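
Here's a sketch of that schedule expressed as per-layer attention masks, reusing the create_longformer_mask helper from earlier (the every-4th-layer choice is an illustrative assumption; production kernels bake the window into the attention computation instead of materializing boolean masks):

layer_schedule.py
import torch

def make_layer_masks(num_layers: int, seq_len: int, window_size: int, global_every: int = 4):
    """Sliding-window masks, with full attention on every `global_every`-th layer."""
    masks = []
    for layer in range(num_layers):
        if (layer + 1) % global_every == 0:
            # Global layer: every token attends to every token
            masks.append(torch.ones(seq_len, seq_len, dtype=torch.bool))
        else:
            # Local layer: sliding window only, no extra global tokens
            masks.append(create_longformer_mask(seq_len, window_size, global_tokens=0))
    return masks

masks = make_layer_masks(num_layers=12, seq_len=2048, window_size=256)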

Attention + SSM (Jamba, Zamba)

Interleave attention layers with Mamba layers:

  • Mamba layers: Efficient long-range propagation
  • Attention layers: Precise retrieval when needed
hybrid_block.py
import torch.nn as nn

class HybridLayer(nn.Module):
    """Pre-norm residual block whose mixer is either attention or a Mamba block.
    MultiHeadAttention, MambaBlock, and FeedForward are assumed to be defined elsewhere."""
    def __init__(self, d_model: int, use_attention: bool):
        super().__init__()
        if use_attention:
            self.mixer = MultiHeadAttention(d_model, num_heads=8)
        else:
            self.mixer = MambaBlock(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model)
    
    def forward(self, x):
        x = x + self.mixer(self.norm1(x))  # pre-norm residual around the mixer
        x = x + self.ffn(self.norm2(x))    # pre-norm residual around the FFN
        return x
 
# Example: every 4th layer uses attention, the rest use Mamba
layers = []
for i in range(32):
    use_attention = (i % 4 == 0)
    layers.append(HybridLayer(d_model=2048, use_attention=use_attention))

Comparison: What Actually Works?

| Method | Complexity | Quality vs Softmax | Training Stable | Long-Range |
| --- | --- | --- | --- | --- |
| Full Attention | O(n²) | ✓ Exact | ✓ Yes | ✓ Excellent |
| FlashAttention | O(n²) | ✓ Exact | ✓ Yes | ✓ Excellent |
| GQA | O(n²) | ~99% | ✓ Yes | ✓ Excellent |
| Sparse (Longformer) | O(n·w) | ~95% | ✓ Yes | ~ Depends on pattern |
| Linear (Performer) | O(n·d²) | ~85-90% | ~ Sometimes | ✗ Degraded |
| Low-rank (Linformer) | O(n·k) | ~90% | ✓ Yes | ~ Limited |
| SSM (Mamba) | O(n) | ~90-95% | ✓ Yes | ~ Good, not perfect |
| Hybrid | O(n) | ~97-99% | ✓ Yes | ✓ Very good |

The uncomfortable truth: FlashAttention + sliding window hybrids currently win for most practical applications. Pure linear attention and SSMs trade too much quality for theoretical efficiency gains.

Final Thoughts: The Quadratic Wall Is Softer Than It Looks

  1. FlashAttention changed the game: O(n²) compute is fine if memory is O(n)
  2. Hybrid is the answer: Mix efficient local attention with sparse global attention
  3. SSMs are complements, not replacements: Great for efficiency, but attention still wins for retrieval
  4. Hardware will catch up: Future GPUs with more SRAM might make the whole debate moot

The quest for O(n) attention isn't over, but the practical ceiling has moved from 4K to 1M+ tokens. For most applications, that's enough.


Next up: How Mamba's selective state spaces actually work, and why the "selectivity" mechanism is the key innovation.


References

  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention Is All You Need. NeurIPS. arXiv:1706.03762

  2. Beltagy, I., Peters, M. E., & Cohan, A. (2020). Longformer: The Long-Document Transformer. arXiv preprint. arXiv:2004.05150

  3. Zaheer, M., Guruganesh, G., Dubey, A., Ainslie, J., Alberti, C., Ontanon, S., Pham, P., Ravula, A., Wang, Q., Yang, L., & Ahmed, A. (2020). Big Bird: Transformers for Longer Sequences. NeurIPS. arXiv:2007.14062

  4. Kitaev, N., Kaiser, Ł., & Levskaya, A. (2020). Reformer: The Efficient Transformer. ICLR. arXiv:2001.04451

  5. Katharopoulos, A., Vyas, A., Pappas, N., & Fleuret, F. (2020). Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. ICML. arXiv:2006.16236

  6. Choromanski, K., Likhosherstov, V., Dohan, D., Song, X., Gane, A., Sarlos, T., Hawkins, P., Davis, J., Mohiuddin, A., Kaiser, L., Belanger, D., Colwell, L., & Weller, A. (2021). Rethinking Attention with Performers. ICLR. arXiv:2009.14794

  7. Qin, Z., Sun, W., Deng, H., Li, D., Wei, Y., Lv, B., Yan, J., Kong, L., & Zhong, Y. (2022). cosFormer: Rethinking Softmax in Attention. ICLR. arXiv:2202.08791

  8. Wang, S., Li, B., Khabsa, M., Fang, H., & Ma, H. (2020). Linformer: Self-Attention with Linear Complexity. arXiv preprint. arXiv:2006.04768

  9. Xiong, Y., Zeng, Z., Chakraborty, R., Tan, M., Fung, G., Li, Y., & Singh, V. (2021). Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention. AAAI. arXiv:2102.03902

  10. Gu, A., Goel, K., & Ré, C. (2022). Efficiently Modeling Long Sequences with Structured State Spaces. ICLR. arXiv:2111.00396

  11. Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv preprint. arXiv:2312.00752

  12. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS. arXiv:2205.14135

  13. Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv preprint. arXiv:2307.08691

  14. Peng, B., Alcaide, E., Anthony, Q., Albalak, A., Arcadinho, S., Biderman, S., Cao, H., Cheng, X., Chung, M., et al. (2023). RWKV: Reinventing RNNs for the Transformer Era. EMNLP. arXiv:2305.13048

  15. Lieber, O., Lenz, B., Ratner, E., et al. (2024). Jamba: A Hybrid Transformer-Mamba Language Model. arXiv preprint. arXiv:2403.19887

  16. Jiang, A. Q., Sablayrolles, A., Roux, A., et al. (2024). Mixtral of Experts. arXiv preprint. arXiv:2401.04088

  17. Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F., & Sanghai, S. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP. arXiv:2305.13245

  18. Shazeer, N. (2019). Fast Transformer Decoding: One Write-Head is All You Need. arXiv preprint. arXiv:1911.02150