Transformer Architecture
TL;DR
The transformer (Vaswani et al., 2017) is the dominant architecture for language foundation models. It replaced RNNs by using the attention mechanism to process all input tokens in parallel. Every LLM you use today (GPT, Claude, Gemini, Llama) is transformer-based.
The historical problem
Before 2017, sequence-to-sequence tasks (translation, summarization) used seq2seq with RNNs (Sutskever et al., 2014). Google put seq2seq into Google Translate in 2016 and called it their biggest translation quality jump ever.
Seq2seq had two fatal problems:
- Information bottleneck: the decoder saw only the final hidden state of the encoder, like summarizing a book from a one-paragraph abstract. Detail was lost.
- Sequential processing: RNNs process tokens one by one. A 200-token input takes 200 sequential steps. You cannot parallelize on GPU. Training is slow.
The transformer fixed both with attention.
How it works
The shift: parallel processing via attention
The transformer dispensed with RNNs entirely. It uses attention to let each output token look at ALL input tokens directly, weighted by relevance. No bottleneck. And input tokens can be processed in parallel because there is no recurrence.
The two inference phases
Every transformer language model runs inference in two phases (critical for understanding the KV cache):
- Prefill: processes the input prompt. All input tokens are attended to in parallel. Builds the key (K) and value (V) vectors for every input token. Compute-bound. Expensive but parallelizable.
- Decode: generates tokens one by one, autoregressively. Each new token attends over all previous K/V vectors. Memory-bound. Slow but small compute per step.
The K and V vectors from prefill are what the KV cache stores, so decode does not recompute them.
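The two phases can be sketched in a few lines of NumPy. This is a toy single-head model with made-up dimensions (d = 8, 5 prompt tokens), not a real LM: prefill builds the K/V cache for the whole prompt at once, and each decode step reuses the cache and appends one row to it.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8                        # toy head dimension (assumption, not a real model size)
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))

def attend(q, K, V):
    # Scaled dot-product attention for one query over the cached keys/values.
    scores = q @ K.T / np.sqrt(d)
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V

# --- Prefill: process all prompt tokens at once, build the KV cache ---
prompt = rng.normal(size=(5, d))            # 5 prompt-token embeddings
K_cache = prompt @ Wk                       # keys for every prompt token, in parallel
V_cache = prompt @ Wv                       # values for every prompt token

# --- Decode: one token at a time, reusing (and extending) the cache ---
x = prompt[-1]                              # stand-in for the current token
for _ in range(3):
    q = x @ Wq
    out = attend(q, K_cache, V_cache)       # attends over ALL cached tokens
    x = out                                 # toy "next token" (no real LM head here)
    K_cache = np.vstack([K_cache, x @ Wk])  # cache grows by one row per step
    V_cache = np.vstack([V_cache, x @ Wv])

print(K_cache.shape)  # cache now covers 5 prompt + 3 generated tokens
```

Note what is never done in decode: recomputing K or V for old tokens. That is the entire point of the KV cache.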
Attention mechanism (the heart of it)
The attention function uses three vectors:
- Query (Q): what the current token is looking for. "I am generating token 5, what context do I need?"
- Key (K): what each previous token advertises. "I am token 2, here is what I am about."
- Value (V): what each previous token contains. "Here is my actual meaning."
Analogy: Q is a person skimming a book, K is each page's title, V is the page's content. The person picks the most relevant pages (high Q-K dot product) and reads their content (V).
Compute:
- Given input x, multiply by learned weight matrices W_Q, W_K, W_V to get Q, K, V
- Attention output = softmax(Q K^T / sqrt(d)) V
- The sqrt(d) scaling stabilizes gradients (d = dimension per head)
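The formula above is short enough to implement directly. A minimal single-head sketch in NumPy, with toy sizes (4 tokens, d = 8) chosen only for illustration:

```python
import numpy as np

def attention(X, Wq, Wk, Wv):
    """Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d)) V."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                    # (n, n) relevance matrix
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax
    return weights @ V                               # weighted mix of value vectors

rng = np.random.default_rng(0)
n, d = 4, 8                                          # toy sizes (assumptions)
X = rng.normal(size=(n, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
out = attention(X, Wq, Wk, Wv)
print(out.shape)   # (4, 8): one output vector per input token
```

Row i of `scores` is token i's query dotted against every key, i.e. the Q-K "page skimming" from the analogy above.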
Multi-head attention
Attention is almost always multi-headed. Llama 2-7B has 32 heads. Each head has its own Q, K, V projection and learns to attend to different things (syntax, entities, long-range, etc.). Outputs are concatenated then projected back to the model's hidden dimension.
Dimensions example for Llama 2-7B:
- Hidden dim = 4096
- 32 attention heads
- Each head: 4096 / 32 = 128 per Q/K/V
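The shape bookkeeping is worth seeing once. A sketch of the split/merge pattern using the Llama 2-7B numbers from the list (the toy sequence length is an assumption; in practice the split is applied to the output of the Q/K/V projections, elided here):

```python
import numpy as np

hidden, n_heads = 4096, 32          # Llama 2-7B sizes from the text
head_dim = hidden // n_heads        # 128 dims per head for Q, K and V
seq_len = 6                         # toy sequence length (assumption)

x = np.zeros((seq_len, hidden))     # stand-in for a projected Q (or K, V) tensor
# Split the hidden dimension into (heads, head_dim), then put heads first so
# each head can run attention independently in its own 128-dim subspace:
heads = x.reshape(seq_len, n_heads, head_dim).transpose(1, 0, 2)
print(heads.shape)     # (32, 6, 128)

# After per-head attention, concatenate heads back to the hidden dimension
# (the output projection W_O then mixes them):
merged = heads.transpose(1, 0, 2).reshape(seq_len, hidden)
print(merged.shape)    # (6, 4096)
```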
Transformer block
A transformer is a stack of identical transformer blocks. Each block contains:
- Attention module (Q, K, V, output projection matrices)
- MLP module (feed-forward network, usually 4x hidden dim wide)
- Residual connections and layer normalization around both
Llama 2-7B has 32 transformer blocks; GPT-4 is rumored to have 96+. The number of blocks is often called the model's depth.
Total parameters mostly come from attention (Q, K, V, O matrices) and MLP (two big matrices). The ratio depends on architecture.
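A back-of-envelope count makes "mostly attention and MLP" concrete. Sketch for a Llama-2-7B-like model; note Llama's MLP is a gated SwiGLU with three matrices and width 11008 rather than the generic 4x two-matrix FFN, and embeddings/norms are ignored here:

```python
# Rough per-block parameter count for a Llama-2-7B-like model (assumed shapes).
hidden, n_blocks = 4096, 32
attn = 4 * hidden * hidden            # Wq, Wk, Wv, Wo
mlp_width = 11008                     # Llama 2-7B's FFN width (~2.7x hidden)
mlp = 3 * hidden * mlp_width          # gate, up and down projections (SwiGLU)
per_block = attn + mlp
print(f"{n_blocks * per_block / 1e9:.1f}B")   # ~6.5B of the ~6.7B total
```

The remainder is mostly the token embedding matrix (32000 x 4096, tied or not) plus small norm vectors.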
Relevance today (2026)
Huyen's 2024 explanation of transformers is still accurate. What has evolved:
- Multi-query and grouped-query attention (MQA/GQA) are now standard for inference efficiency. All heads share one K/V projection (MQA), or groups of heads each share one (GQA, used in Llama 3+). Cuts KV cache memory by 8x+ with only minor quality loss.
- Rotary Position Embedding (RoPE) replaced absolute positional encodings in most new models. Better long-context generalization.
- FlashAttention 2 and 3 (2023-2024) rewrote the attention kernel for GPU memory efficiency. Standard in every serving stack now.
- Alternative architectures are real but niche.
- Mamba 2 and state-space models: O(N) inference, good for very long context, not yet beating transformers on general tasks
- Hybrid models (Jamba, Zamba): mix transformer and state-space layers, good cost-quality trade-off
- RWKV: RNN-with-transformer-tricks, community-driven
- In 2026, transformers still win on quality per FLOP for most tasks
- Sliding window attention (e.g., Mistral) limits attention to a fixed window (e.g., the last 4K tokens) to control KV cache growth.
- Context length exploded. 2024 frontier: 128K. 2026 frontier: 1M-10M tokens (Gemini 2 Ultra, Claude Opus 4.x with research preview). Making attention scale to 1M+ needed new tricks (Ring Attention, chunk attention, hierarchical retrieval inside attention).
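The GQA trick from the list above is mechanically simple: store and cache only the small set of KV heads, then repeat each one across its query group at matmul time. A sketch with assumed Llama-3-8B-style numbers (32 query heads, 8 KV heads):

```python
import numpy as np

n_q_heads, n_kv_heads, head_dim, seq = 32, 8, 128, 4   # assumed sizes; seq is toy
group = n_q_heads // n_kv_heads                        # 4 query heads per KV head

rng = np.random.default_rng(0)
q = rng.normal(size=(n_q_heads, seq, head_dim))
k = rng.normal(size=(n_kv_heads, seq, head_dim))       # only 8 K heads are cached
k_expanded = np.repeat(k, group, axis=0)               # repeat to 32 for the matmul
scores = q @ k_expanded.transpose(0, 2, 1) / np.sqrt(head_dim)
print(scores.shape)   # (32, 4, 4): full per-head scores from a 4x smaller K cache
```

The same repeat is applied to V. The cache shrinks by `n_q_heads / n_kv_heads` (4x here, 32x for pure MQA) while the attention math is unchanged.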
Question to keep in mind: is the transformer permanent or just entrenched? Every year someone predicts its death. Every year it still wins the production benchmarks. Bet on transformers for now, watch state-space as the most credible challenger.
Critical questions
- Why does the sqrt(d) in the attention score matter? What breaks if you remove it?
- If attention is parallelizable for input, why is generation still sequential? What would it take to fix that?
- Multi-query attention degrades quality slightly compared to full multi-head. Why do all new models use MQA/GQA anyway?
- Could a model with zero attention layers (pure MLP + some trick) work? (Spoiler: researchers have tried, results are mediocre so far.)
- Why 32 attention heads and not 64 or 16? What drives that choice?
Production pitfalls
- Naive attention at long context is O(N^2) memory. Without FlashAttention or similar, a 32K-context request OOMs on a single GPU. Always use modern serving frameworks (vLLM, TGI, SGLang).
- KV cache eats VRAM. For Llama 2-7B at FP16 with 32K context, KV cache alone is gigabytes. Do the math before deploying.
- Mismatch between position encoding in training and inference. Some models (older Llama) trained on 4K context generate garbage past 4K unless you use RoPE scaling (NTK-aware, YaRN). Check the model card.
- MQA/GQA quality degradation varies. Some fine-tunes lose more than others. Test on your specific eval set before switching.
- Numerical precision issues. Attention in FP16 can overflow on long sequences. Modern kernels use BF16 or mixed precision.
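"Do the math" for the KV cache pitfall above, using Llama 2-7B's actual shapes (32 layers, full multi-head attention so 32 KV heads, head_dim 128) at FP16 with 32K context:

```python
# Back-of-envelope KV cache size for one 32K-token sequence, Llama 2-7B, FP16.
n_layers, n_kv_heads, head_dim = 32, 32, 128
seq_len, bytes_per = 32_768, 2          # FP16 = 2 bytes per element
cache = seq_len * n_layers * n_kv_heads * head_dim * 2 * bytes_per  # 2x for K and V
print(f"{cache / 2**30:.0f} GiB")       # 16 GiB - more than the ~13 GiB of weights
```

And that is per sequence: batching multiplies it, which is exactly why MQA/GQA and paged KV cache management exist.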
Alternatives / Comparisons
| Architecture | What it brings | Status in 2026 |
|---|---|---|
| Transformer (vanilla) | Strong quality, parallel training | Dominant |
| Transformer + MQA/GQA | Better inference efficiency | Default for new models |
| Transformer + sliding window | Controlled KV cache growth | Mistral/Llama 3, common |
| Mamba 2 | O(N) inference, very long context | Research, small niches |
| Jamba (hybrid) | Transformer + Mamba, best of both | Early production |
| RWKV | RNN-style, constant-memory inference | Community, small |
| Pure MLP (gMLP, MLP-Mixer) | Simplicity | Failed for sequences |
Mini-lab
labs/attention-from-scratch/ (to create) - implement a single attention head in ~50 lines of NumPy or PyTorch. Compute Q, K, V for a toy input, run the attention formula, visualize the attention weights as a heatmap.
Suggested starter: Karpathy's nanoGPT attention implementation. Focus on understanding the shapes and the softmax.
Further reading
- "Attention Is All You Need" (Vaswani et al., 2017) - the original paper
- Karpathy, "Let's build GPT from scratch" (YouTube) - step by step build
- Lilian Weng, "The Transformer Family" (lilianweng.github.io) - best blog post for variants
- FlashAttention paper (Dao et al., 2022) and FlashAttention-2 (2023) - memory efficiency
- Huyen, Chapter 2 - the version this note extends
- Gu and Dao, "Mamba: Linear-Time Sequence Modeling" (2023) - leading alternative