LLM Architecture & Internals

A comprehensive technical reference covering everything from tokenization to inference optimization

1. High-Level Overview

Large Language Models (LLMs) are neural networks — typically based on the Transformer architecture — trained on massive text corpora to predict the next token in a sequence. Despite the apparent simplicity of this objective, the resulting models exhibit remarkable emergent capabilities: reasoning, translation, code generation, summarisation, and more.

LLM PIPELINE OVERVIEW

Raw Text ──► Tokenizer ──► Token IDs ──► Embedding Layer
                                              │
                                   ┌──────────▼──────────┐
                                   │  Transformer Block  │
                                   │  ┌───────────────┐  │
                                   │  │  Attention    │  │
                                   │  │  + FFN        │  │  × N
                                   │  │  + LayerNorm  │  │
                                   │  └───────────────┘  │
                                   └──────────┬──────────┘
                                              │
                                     Linear Projection
                                              │
                                           Softmax
                                              │
                                  Next-Token Probabilities
                                              │
                                     Sampling / Argmax
                                              │
                                         Output Token

Key Dimensions of an LLM

Parameter         | Symbol  | Typical Range    | Description
------------------|---------|------------------|--------------------------------------
Layers            | L       | 24 – 128         | Number of stacked transformer blocks
Hidden dimension  | d_model | 2048 – 18432     | Width of the residual stream
Attention heads   | h       | 16 – 128         | Parallel attention patterns per layer
Head dimension    | d_k     | 64 – 256         | d_model / h
FFN dimension     | d_ff    | 4× to 8× d_model | Width of feed-forward hidden layer
Vocabulary size   | V       | 32K – 256K       | Number of unique tokens
Context length    | T       | 2K – 1M+         | Maximum input sequence length
Parameters        | N       | 1B – 1T+         | Total trainable weights

2. Tokenization

Tokenization is the very first step: converting raw text into a sequence of integer IDs that the model can process. The quality of the tokenizer directly affects model efficiency, multilinguality, and even downstream task performance.

Byte-Pair Encoding (BPE)

BPE is the most widely used tokenization algorithm in modern LLMs (GPT family, LLaMA, etc.). It starts with individual bytes or characters and iteratively merges the most frequent adjacent pair into a new token, building a vocabulary bottom-up.

BPE Training Algorithm

1. Initialize vocabulary with all individual bytes (256 tokens)
2. Count all adjacent token pairs in the corpus
3. Find the most frequent pair (e.g., "t" + "h" → "th")
4. Merge all occurrences of that pair into a single new token
5. Add the new token to the vocabulary
6. Repeat steps 2–5 until vocabulary reaches target size (e.g., 50,257)

Example progression:
  "l o w e r" → "lo w e r" → "low e r" → "low er" → "lower"
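The training loop above can be sketched in a few lines of Python. This is an illustrative toy implementation (the `bpe_train` helper is hypothetical), not a production tokenizer — real BPE trainers operate on byte sequences and pre-tokenized chunks:

```python
from collections import Counter

def bpe_train(corpus, num_merges):
    """Toy BPE trainer: learn merge rules from a list of words."""
    # Represent each word as a tuple of symbols, weighted by frequency
    vocab = Counter(tuple(word) for word in corpus)
    merges = []
    for _ in range(num_merges):
        # Count all adjacent symbol pairs across the corpus
        pairs = Counter()
        for symbols, freq in vocab.items():
            for a, b in zip(symbols, symbols[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # most frequent pair
        merges.append(best)
        # Merge every occurrence of the best pair into one new symbol
        new_vocab = Counter()
        for symbols, freq in vocab.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_vocab[tuple(merged)] += freq
        vocab = new_vocab
    return merges

merges = bpe_train(["lower", "lowest", "low", "low"], num_merges=2)
```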

BPE Encoding (Inference)

Input:  "unhappiness"
Step 1: Split into characters: ["u", "n", "h", "a", "p", "p", "i", "n", "e", "s", "s"]
Step 2: Apply merges in priority order:
        "u" + "n" → "un"
        "h" + "a" → "ha"
        "p" + "p" → "pp"
        "ha" + "pp" → "happ"
        "un" + "happ" → "unhapp"
        "i" + "n" → "in"
        "s" + "s" → "ss"
        "e" + "ss" → "ess"
        "in" + "ess" → "iness"
Result: ["unhapp", "iness"]  →  Token IDs: [48291, 7274]

Why bytes, not characters?

Byte-level BPE (used by GPT-2 onwards) operates on raw UTF-8 bytes instead of Unicode characters. This means the tokenizer can handle any language or script — it will never encounter an unknown character, since every string is representable as bytes.
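Encoding with a trained merge list can be sketched as follows. This toy version applies each merge in training-priority order with one pass per rule; real tokenizers (e.g. tiktoken) repeatedly apply the lowest-rank applicable merge over byte sequences:

```python
def bpe_encode(word, merges):
    """Toy BPE encoder: apply learned merges in priority order."""
    symbols = list(word)
    for pair in merges:  # merges are ordered by training priority
        i = 0
        while i < len(symbols) - 1:
            if (symbols[i], symbols[i + 1]) == pair:
                # Collapse the matching pair into a single symbol
                symbols[i:i + 2] = [symbols[i] + symbols[i + 1]]
            else:
                i += 1
    return symbols

tokens = bpe_encode("lower", [("l", "o"), ("lo", "w"), ("e", "r"), ("low", "er")])
```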

SentencePiece & Unigram

SentencePiece (used by T5, LLaMA, Gemma) treats the input as a raw byte stream without pre-tokenization, making it truly language-agnostic. It supports two algorithms:

  • BPE mode: Same merging logic as standard BPE but applied without whitespace pre-splitting.
  • Unigram mode: Starts with a large initial vocabulary and iteratively removes tokens that least affect the training loss, keeping those that maximise the likelihood of the corpus under a unigram language model. At inference, it finds the most probable segmentation using the Viterbi algorithm.

Tokenizer Comparison

Tokenizer             | Algorithm      | Vocab Size | Used By
----------------------|----------------|------------|----------------
tiktoken (cl100k)     | Byte-level BPE | 100,256    | GPT-4, GPT-3.5
tiktoken (o200k)      | Byte-level BPE | 200,019    | GPT-4o, o1
SentencePiece BPE     | BPE            | 32,000     | LLaMA 1/2
tiktoken-style BPE    | Byte-level BPE | 128,256    | LLaMA 3
SentencePiece Unigram | Unigram        | 32,100     | T5
WordPiece             | BPE variant    | 30,522     | BERT

Special Tokens

Every tokenizer defines special tokens that serve structural roles:

  • <|begin_of_text|> / <s> — Beginning of sequence
  • <|end_of_text|> / </s> — End of sequence
  • <|pad|> — Padding for batch alignment
  • <|im_start|> / <|im_end|> — Chat role delimiters
  • [MASK] — Used in masked language modelling (BERT)

3. Embeddings

After tokenization, each token ID is mapped to a dense vector using an embedding lookup table. This table is a learnable matrix W_E ∈ ℝ^(V×d), where V is the vocabulary size and d is the model's hidden dimension.

# Embedding lookup (PyTorch)
import torch.nn as nn

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)

    def forward(self, token_ids):
        # token_ids: (batch, seq_len) → (batch, seq_len, d_model)
        return self.embed(token_ids)

Positional Encoding

Transformers have no inherent sense of position — the self-attention operation is permutation-equivariant. Positional information must be injected explicitly. Several schemes exist:

Sinusoidal Positional Encoding (Original Transformer)

PE(pos, 2i) = sin(pos / 10000^(2i/d))      PE(pos, 2i+1) = cos(pos / 10000^(2i/d))

Fixed (non-learnable) sinusoidal functions where each dimension oscillates at a different frequency. Rarely used in modern LLMs.
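Even though sinusoidal encodings are rare in modern LLMs, they remain a useful baseline. A minimal PyTorch sketch of the formula above:

```python
import math
import torch

def sinusoidal_pe(max_len, d_model):
    """Fixed sinusoidal positional encodings from the original Transformer."""
    pe = torch.zeros(max_len, d_model)
    pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
    # 1 / 10000^(2i/d) for each even dimension index 2i
    div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(pos * div)  # even dimensions
    pe[:, 1::2] = torch.cos(pos * div)  # odd dimensions
    return pe

pe = sinusoidal_pe(128, 64)
```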

Learned Absolute Positional Embeddings

A separate learnable embedding table WP ∈ ℝT × d is added to the token embeddings. Simple but cannot generalise beyond the trained context length. Used in GPT-2.

Rotary Position Embedding (RoPE)

RoPE is the dominant positional encoding in modern LLMs (LLaMA, Mistral, Qwen, and many others). Instead of adding position information to the input, it rotates the query and key vectors in the attention mechanism as a function of their absolute position.

RoPE(x_m, m) = x_m ⊙ cos(mθ) + rotate(x_m) ⊙ sin(mθ)

Where θ_i = 10000^(−2i/d) and rotate swaps adjacent pairs and negates alternating elements. The key property: the dot product between two RoPE-encoded vectors depends only on their relative position, not their absolute positions.

# Simplified RoPE implementation
import torch

def apply_rope(x, freqs_cis):
    """x: (batch, seq_len, n_heads, head_dim); freqs_cis: (seq_len, head_dim // 2) complex"""
    # View adjacent feature pairs as complex numbers
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Broadcast freqs over the batch and head dimensions
    freqs_cis = freqs_cis.view(1, x_complex.shape[1], 1, x_complex.shape[-1])
    # Element-wise complex multiplication = rotation
    x_rotated = x_complex * freqs_cis
    return torch.view_as_real(x_rotated).reshape_as(x).type_as(x)

def precompute_freqs(dim, max_seq_len, theta=10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len).float()
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)  # e^(i·mθ)

ALiBi (Attention with Linear Biases)

Instead of modifying embeddings, ALiBi adds a linear bias to attention scores based on the distance between query and key positions. Each head h gets a different slope m_h, so the bias at distance d is -m_h · |d|. Used in BLOOM and MPT.
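A sketch of how the ALiBi bias matrix could be constructed, using the standard geometric slope schedule for a power-of-two head count (the `alibi_bias` helper is illustrative, not BLOOM's actual code):

```python
import torch

def alibi_bias(n_heads, seq_len):
    """ALiBi bias per head: -slope_h * |i - j|, slopes a geometric sequence."""
    # Standard slopes when n_heads is a power of 2: 2^(-8/n), 2^(-16/n), ...
    slopes = torch.tensor([2 ** (-8 * (h + 1) / n_heads) for h in range(n_heads)])
    pos = torch.arange(seq_len)
    dist = (pos[None, :] - pos[:, None]).abs()  # (seq, seq) distance matrix
    return -slopes[:, None, None] * dist        # (n_heads, seq, seq), added to scores

bias = alibi_bias(8, 16)
```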

Why RoPE Dominates

RoPE offers the best combination of: (a) encoding relative position naturally, (b) generalising to longer sequences via NTK-aware interpolation or YaRN, (c) computational efficiency, and (d) compatibility with KV caching during inference.

4. The Transformer Block

The Transformer block is the fundamental computational unit. A typical LLM stacks L identical blocks (24 to 128 layers), each consisting of a self-attention sub-layer followed by a feed-forward network, with normalization and residual connections.

       Input x
          │
   ┌──────▼──────┐
   │  LayerNorm  │ ◄── Pre-Norm (modern LLMs)
   └──────┬──────┘
          │
   ┌──────▼──────┐
   │ Multi-Head  │
   │  Attention  │──── Q, K, V projections
   └──────┬──────┘
          │
   ┌──────▼──────┐
   │ + Residual  │ ◄── x + Attention(LN(x))
   └──────┬──────┘
          │
   ┌──────▼──────┐
   │  LayerNorm  │
   └──────┬──────┘
          │
   ┌──────▼──────┐
   │ Feed-Forward│
   │   Network   │──── SwiGLU / GeGLU
   └──────┬──────┘
          │
   ┌──────▼──────┐
   │ + Residual  │ ◄── x + FFN(LN(x))
   └──────┬──────┘
          │
       Output

Self-Attention Mechanism

Self-attention allows every token in a sequence to attend to every other token, computing a weighted combination of all value vectors based on the similarity between query-key pairs.

Step-by-Step Computation

# Given input X ∈ ℝ^(seq_len × d_model)

# 1. Linear projections
Q = X @ W_Q    # (seq_len, d_model) × (d_model, d_k) → (seq_len, d_k)
K = X @ W_K    # (seq_len, d_model) × (d_model, d_k) → (seq_len, d_k)
V = X @ W_V    # (seq_len, d_model) × (d_model, d_v) → (seq_len, d_v)

# 2. Compute attention scores
scores = Q @ K.T           # (seq_len, seq_len)
scores = scores / sqrt(d_k) # Scale to prevent softmax saturation

# 3. Apply causal mask (for decoder / autoregressive models)
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
scores.masked_fill_(mask, float('-inf'))

# 4. Softmax to get attention weights
weights = softmax(scores, dim=-1)  # (seq_len, seq_len), rows sum to 1

# 5. Weighted sum of values
output = weights @ V  # (seq_len, d_v)

Attention(Q, K, V) = softmax( QK^T / √d_k ) V

Computational Cost

The QK^T matrix multiplication is O(n²d) in time and O(n²) in memory, where n is the sequence length. This quadratic scaling is the primary bottleneck for long-context models, motivating techniques like FlashAttention, sparse attention, and linear attention approximations.

Multi-Head Attention (MHA)

Rather than performing a single attention function, Multi-Head Attention runs h parallel attention operations ("heads"), each with its own Q, K, V projections at a reduced dimension d_k = d_model / h. The outputs are concatenated and linearly projected.

# Multi-Head Attention
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        B, T, C = x.shape

        # Project and reshape to (B, n_heads, T, d_k)
        q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention per head
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores.masked_fill_(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = attn @ v  # (B, n_heads, T, d_k)

        # Concatenate heads and project
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(out)

Each head can learn different attention patterns — some heads attend to nearby tokens (local patterns), some to syntactic structures, some to semantic relationships, and some to positional patterns.

Multi-Query Attention (MQA) & Grouped-Query Attention (GQA)

Standard MHA requires storing separate K and V tensors for every head in the KV cache during inference, which becomes a major memory bottleneck. Two variants reduce this cost:

VariantK/V HeadsMemory SavingsUsed By
MHA (standard)h (same as Q)BaselineGPT-3, BERT
MQA1 (shared)~h× reductionPaLM, Falcon
GQAg groups (1 < g < h)~h/g× reductionLLaMA 2/3, Mistral, Gemma
            MHA                       GQA (g=4)                   MQA
Q:   1 2 3 4 5 6 7 8           1 2 3 4 5 6 7 8          1 2 3 4 5 6 7 8
     │ │ │ │ │ │ │ │           └┬┘ └┬┘ └┬┘ └┬┘          └──────┬──────┘
K/V: 1 2 3 4 5 6 7 8            1   2   3   4                  1

  8 unique K/V per layer      4 K/V groups (g=4)         1 shared K/V
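A GQA or MQA implementation can reuse standard attention code by repeating each cached K/V head across its query-head group. A minimal sketch (the `expand_kv` helper is illustrative):

```python
import torch

def expand_kv(k, n_q_heads):
    """Share each KV head across a group of query heads by repeating it.
    k: (batch, n_kv_heads, seq, head_dim)"""
    b, n_kv, t, d = k.shape
    group = n_q_heads // n_kv                 # query heads per KV head
    return k.repeat_interleave(group, dim=1)  # (batch, n_q_heads, seq, head_dim)

k = torch.randn(1, 2, 5, 8)    # 2 KV heads in the cache
k_expanded = expand_kv(k, 8)   # serve 8 query heads
```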

Feed-Forward Network (FFN)

After the attention sub-layer, each token's representation passes through a position-wise feed-forward network. The standard FFN is a two-layer MLP that projects up to a higher dimension, applies a nonlinearity, and projects back down.

Standard FFN

FFN(x) = W_2 · σ(W_1 x + b_1) + b_2

Where W_1 ∈ ℝ^(4d×d) and W_2 ∈ ℝ^(d×4d). The expansion ratio is typically 4×.

SwiGLU (Used in LLaMA, Mistral, Gemma)

SwiGLU replaces the standard ReLU activation with a gated variant using the SiLU (Swish) activation function. It uses three weight matrices instead of two:

SwiGLU(x) = W_down · ( SiLU(W_gate x) ⊙ (W_up x) )

import torch.nn as nn
import torch.nn.functional as F

class SwiGLU_FFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.w_up   = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

FFN as Key-Value Memory

Research suggests that FFN layers function as massive key-value stores. The first projection (W_1) computes "keys" that match input patterns, and the second projection (W_2) retrieves associated "values" (knowledge). This is why FFN layers store much of the model's factual knowledge.

Normalization

Normalization stabilises training by controlling the scale of activations. Two main approaches:

Layer Normalization (LayerNorm)

LayerNorm(x) = γ ⊙ (x - μ) / √(σ² + ε) + β

Normalises across the feature dimension for each token independently. μ and σ are the mean and standard deviation computed over the d_model dimensions.

RMSNorm (Root Mean Square Normalization)

RMSNorm(x) = γ ⊙ x / √( mean(x²) + ε )

A simplified variant that skips mean-centering (no subtraction of μ, no β bias). Computationally cheaper and equally effective. Used in LLaMA, Mistral, Gemma, and most modern architectures.
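A minimal RMSNorm implementation matching the formula above (a sketch of the LLaMA-style layer, not its exact code):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Scale by the reciprocal RMS; no mean-centering, no bias."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # the learnable gain γ

    def forward(self, x):
        # rsqrt of the per-token mean of squares
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * rms)

norm = RMSNorm(16)
out = norm(torch.randn(2, 4, 16))
```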

Pre-Norm vs. Post-Norm

The original Transformer applied normalization after the residual addition ("Post-Norm"). Modern LLMs universally use "Pre-Norm" — normalizing the input to each sub-layer before passing it through attention/FFN. Pre-Norm provides more stable gradients during training, enabling training at larger scales without careful learning rate scheduling.

Residual Connections

Every sub-layer (attention and FFN) is wrapped in a residual (skip) connection: output = x + SubLayer(Norm(x)). This creates a direct gradient highway from the output layer all the way back to the input, preventing vanishing gradients in very deep networks. The residual stream can be thought of as a shared "communication bus" that all layers read from and write to.

The Residual Stream Interpretation A powerful way to understand Transformers: the residual stream is the primary data structure. Each attention head and FFN reads from this stream, computes an update, and adds it back. The final output is the sum of the original embedding plus all layer contributions. This perspective underlies much of modern mechanistic interpretability research.

5. Architecture Variants

Decoder-Only (Causal / Autoregressive)

The dominant architecture for modern LLMs. Uses causal masking so each token can only attend to itself and all preceding tokens. This is natural for text generation — the model predicts the next token given all previous tokens.

  • Examples: GPT family, LLaMA, Mistral, Claude, Gemini, Falcon, Phi, Qwen
  • Training objective: Next-token prediction (causal language modelling)
  • Advantages: Simple, scales well, naturally supports generation, can be used for any NLP task via prompting

Encoder-Decoder

Uses a bidirectional encoder (no causal mask — every token sees every other token) to process the input, and an autoregressive decoder with cross-attention to the encoder representations to generate the output.

  • Examples: T5, BART, UL2, Flan-T5
  • Training objective: Span corruption (mask spans in input, predict them in output)
  • Advantages: Naturally suited for sequence-to-sequence tasks (translation, summarisation)

Mixture of Experts (MoE)

MoE replaces the dense FFN in each Transformer block with a set of N "expert" FFN sub-networks and a lightweight gating/routing mechanism. For each token, only the top-K experts (typically K=2) are activated, keeping the computational cost per token manageable even as the total parameter count grows massively.

             Input token x
                  │
           ┌──────▼──────┐
           │   Router    │
           │  (Linear +  │
           │   Softmax)  │
           └──────┬──────┘
                  │
   ┌────────┬─────┼──────┬────────┐
   │        │     │      │        │
┌──▼──┐ ┌──▼──┐ ┌─▼───┐ ┌▼────┐ ┌─▼───┐
│ E_1 │ │ E_2 │ │ E_3 │ │ E_4 │ │…E_N │
│(FFN)│ │(FFN)│ │(FFN)│ │(FFN)│ │(FFN)│
└──┬──┘ └──┬──┘ └──┬──┘ └──┬──┘ └──┬──┘
   │       │       │       │       │
   └───────┴───┬───┴───────┴───────┘
               │  (only top-K are computed)
               │
   Weighted sum of selected expert outputs
               │
            Output

Router / Gating Mechanism

# Simplified top-K routing
gate_logits = x @ W_gate                # (batch*seq, n_experts)
gate_scores = softmax(gate_logits, dim=-1)
top_k_scores, top_k_indices = gate_scores.topk(K, dim=-1)
top_k_scores = top_k_scores / top_k_scores.sum(dim=-1, keepdim=True)  # renormalize

# Dispatch each token to its selected experts, compute, and combine
output = sum(score * experts[idx](x)
             for idx, score in zip(top_k_indices, top_k_scores))

Load Balancing

A common problem with MoE is "expert collapse" — the router learns to send most tokens to a few experts while others go unused. This is addressed with an auxiliary load-balancing loss that encourages uniform expert utilisation:

L_balance = α · N · Σ_i f_i · P_i

Where f_i is the fraction of tokens routed to expert i, and P_i is the average router probability for expert i.
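The balancing loss can be sketched directly from the definitions of f_i and P_i (a simplified Switch-Transformer-style version with top-1 dispatch counts; `load_balancing_loss` is an illustrative helper):

```python
import torch

def load_balancing_loss(router_probs, expert_indices, n_experts, alpha=0.01):
    """alpha * N * sum_i f_i * P_i.
    router_probs: (tokens, n_experts) softmax outputs
    expert_indices: (tokens,) chosen expert per token"""
    # f_i: fraction of tokens dispatched to expert i
    f = torch.bincount(expert_indices, minlength=n_experts).float() / expert_indices.numel()
    # P_i: mean router probability assigned to expert i
    P = router_probs.mean(dim=0)
    return alpha * n_experts * (f * P).sum()

probs = torch.full((100, 4), 0.25)  # perfectly uniform router
loss = load_balancing_loss(probs, torch.arange(100) % 4, n_experts=4)
```

For a perfectly uniform router the loss reduces to alpha, its minimum; skewed routing increases it.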

Model         | Total Params | Active Params | Experts | Top-K
--------------|--------------|---------------|---------|------
Mixtral 8x7B  | ~47B         | ~13B          | 8       | 2
Mixtral 8x22B | ~141B        | ~39B          | 8       | 2
Grok-1        | 314B         | ~86B          | 8       | 2
DeepSeek-V2   | 236B         | ~21B          | 160     | 6
DBRX          | 132B         | ~36B          | 16      | 4

6. Pre-Training

Pre-training is the foundational phase where the model learns language understanding, world knowledge, and reasoning capabilities from massive text corpora. This is by far the most compute-intensive stage.

Training Objectives

Causal Language Modelling (CLM)

The standard objective for decoder-only models. Given a sequence of tokens (x1, ..., xT), minimise the negative log-likelihood of predicting each next token:

L_CLM = − Σ_{t=1}^{T} log P(x_t | x_1, ..., x_{t−1}; θ)
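In practice this objective reduces to cross-entropy with targets shifted one position left, since the logits at position t are the prediction for token t+1. A minimal sketch:

```python
import torch
import torch.nn.functional as F

def clm_loss(logits, tokens):
    """Causal LM loss. logits: (batch, seq, vocab); tokens: (batch, seq)."""
    return F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),  # predictions at positions 0..T-2
        tokens[:, 1:].reshape(-1),                    # targets are the next tokens
    )

logits = torch.zeros(2, 8, 100)  # uniform logits over a 100-token vocab
loss = clm_loss(logits, torch.randint(0, 100, (2, 8)))
```

With uniform logits the loss equals log V, the entropy of a uniform distribution over the vocabulary.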

Masked Language Modelling (MLM)

Used for encoder-only models (BERT). Randomly mask 15% of tokens and predict the originals. Produces bidirectional representations but is not directly usable for generation.

Span Corruption

Used for encoder-decoder models (T5). Replace random spans of text with sentinel tokens in the input; the decoder must predict the original spans. More sample-efficient than MLM.

Data Pipeline

The quality and composition of training data is arguably as important as architecture. A typical data pipeline:

Raw Web Crawl (Common Crawl, ~250B pages)
    │
    ├─► Language Detection (keep target languages)
    ├─► URL Filtering (remove adult/spam/low-quality domains)
    ├─► HTML Extraction (readability, trafilatura)
    ├─► Deduplication
    │   ├── Exact: Hash-based (SHA-256)
    │   ├── Near-duplicate: MinHash + LSH (Jaccard similarity)
    │   └── Substring: Suffix array-based
    ├─► Quality Filtering
    │   ├── Perplexity filter (KenLM trained on Wikipedia)
    │   ├── Heuristic rules (line length, special char ratio, etc.)
    │   └── Classifier-based (trained on human quality judgments)
    ├─► PII Removal (emails, phone numbers, SSNs, etc.)
    ├─► Toxicity / Safety Filtering
    └─► Data Mixing
        ├── Web text: ~60-80%
        ├── Code (GitHub): ~5-15%
        ├── Books: ~5-10%
        ├── Academic papers: ~3-5%
        ├── Wikipedia: ~2-5%
        ├── Curated Q&A: ~1-3%
        └── Math: ~1-3%

Training Data Sizes

Model          | Training Tokens | Parameters
---------------|-----------------|-----------
GPT-3          | 300B            | 175B
Chinchilla     | 1.4T            | 70B
LLaMA 1 (65B)  | 1.4T            | 65B
LLaMA 2 (70B)  | 2T              | 70B
LLaMA 3 (405B) | 15T+            | 405B
Mistral Large  | ~12T (est.)     | ~123B

Scaling Laws

Research from Kaplan et al. (OpenAI) and Hoffmann et al. (DeepMind / Chinchilla) established power-law relationships between model performance and three factors: parameter count (N), training data (D), and compute budget (C).

Chinchilla Optimal Scaling

L(N, D) = E + A/N^α + B/D^β

Key finding: for a fixed compute budget, the optimal model trains approximately 20 tokens per parameter. A 70B model should be trained on ~1.4T tokens. Many modern models are "over-trained" beyond this optimum to get better per-inference performance at the cost of more training compute.

The Compute–Performance Relationship

Compute (FLOPs) ≈ 6 × N × D
Where:
  N = number of parameters
  D = number of training tokens
  6 ≈ 2 (forward) + 4 (backward) FLOPs per parameter per token

Example: Training a 70B model on 2T tokens
  = 6 × 70×10⁹ × 2×10¹² = 8.4 × 10²³ FLOPs
  ≈ 840 ZettaFLOPs
  On 2048 H100s at 50% MFU (~990 TFLOPS peak BF16 each): ~10 days

Emergent Abilities

Certain capabilities appear abruptly as models scale — performing at chance level below some threshold and then jumping to high accuracy. Examples include multi-step arithmetic, chain-of-thought reasoning, and code execution tracing. The precise cause remains debated; some researchers argue it is partly an artefact of evaluation metrics.

7. Fine-Tuning & Alignment

After pre-training, the model is a capable but unrefined text predictor. The alignment pipeline transforms it into a helpful, harmless, and honest assistant.

┌──────────────┐    ┌──────────────┐    ┌──────────────┐    ┌──────────────┐
│ Pre-trained  │ ──►│  Supervised  │ ──►│    RLHF /    │ ──►│   Aligned    │
│     Base     │    │ Fine-Tuning  │    │     DPO      │    │    Model     │
│    Model     │    │    (SFT)     │    │              │    │              │
└──────────────┘    └──────────────┘    └──────────────┘    └──────────────┘
       ↑                   ↑                   ↑                   ↑
  Trillions of       100K–1M curated     Human preference      Ready for
  web tokens         demonstrations       comparisons          deployment

Supervised Fine-Tuning (SFT)

The model is trained on carefully curated (prompt, response) pairs demonstrating desired assistant behaviour. The training objective is still next-token prediction, but only on the response portion (the prompt tokens are masked from the loss).

# SFT loss computation (simplified)
for prompt, response in dataset:
    full_sequence = prompt + response
    tokens = tokenize(full_sequence)
    logits = model(tokens)

    # Only compute loss on response tokens. Logits at position i predict
    # token i+1, so predictions for the response start at prompt_len - 1.
    prompt_len = len(tokenize(prompt))
    loss = cross_entropy(
        logits[prompt_len - 1 : -1],  # predictions for response tokens
        tokens[prompt_len:]           # target response tokens
    )

RLHF & DPO

Reinforcement Learning from Human Feedback (RLHF)

A three-step process:

  1. Reward Model Training: Human annotators rank model responses (e.g., response A is better than B). A reward model (often the same architecture as the LLM) is trained on these preferences using a Bradley-Terry pairwise loss:
    L_RM = -log σ( r(x, y_w) - r(x, y_l) )
    where y_w is the preferred response and y_l is the rejected one.
  2. PPO Optimisation: The SFT model is further trained using Proximal Policy Optimisation to maximise the reward model's score while staying close to the SFT policy (via a KL divergence penalty):
    max_π E[ r(x,y) - β · KL(π ‖ π_SFT) ]
  3. Iterative refinement: Collect new preference data on the RLHF model's outputs and repeat.

Direct Preference Optimisation (DPO)

DPO eliminates the separate reward model by directly optimising the policy on preference data. It reparameterises the RLHF objective to show that the optimal policy has a closed-form relationship with the reward:

L_DPO = -log σ( β · [ log π(y_w|x)/π_ref(y_w|x) - log π(y_l|x)/π_ref(y_l|x) ] )

DPO is simpler (no reward model, no RL loop), more stable, and increasingly popular. Variants include IPO, KTO (which works with binary feedback rather than pairs), and ORPO.
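Given per-sequence log-probabilities under the policy and the frozen reference model, the DPO loss is only a few lines. A sketch (the `dpo_loss` helper is illustrative):

```python
import torch
import torch.nn.functional as F

def dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """-log σ(β · [(logπ_w - logπref_w) - (logπ_l - logπref_l)]).
    Each argument is a tensor of per-sequence log-probs."""
    margin = (logp_w - ref_logp_w) - (logp_l - ref_logp_l)
    return -F.logsigmoid(beta * margin).mean()

# If policy and reference agree exactly, the margin is 0 and the loss is log 2
loss = dpo_loss(torch.tensor([-5.0]), torch.tensor([-6.0]),
                torch.tensor([-5.0]), torch.tensor([-6.0]))
```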

Parameter-Efficient Fine-Tuning (PEFT)

Full fine-tuning updates all parameters, requiring significant memory and compute. PEFT methods freeze most weights and train a small number of additional parameters.

LoRA (Low-Rank Adaptation)

The most widely used PEFT method. For each target weight matrix W, LoRA freezes W and adds a low-rank decomposition:

W' = W + α · B · A     where A ∈ ℝr×d, B ∈ ℝd×r, r ≪ d
# LoRA applied to a linear layer
class LoRALinear(nn.Module):
    def __init__(self, original_linear, rank=16, alpha=32):
        super().__init__()
        self.original = original_linear  # frozen
        d_in, d_out = original_linear.weight.shape[1], original_linear.weight.shape[0]

        self.lora_A = nn.Parameter(torch.randn(rank, d_in) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(d_out, rank))
        self.scaling = alpha / rank

    def forward(self, x):
        original_out = self.original(x)
        lora_out = (x @ self.lora_A.T) @ self.lora_B.T * self.scaling
        return original_out + lora_out

# Typical ranks: 8–64. Trainable params: 0.1–1% of total model params.

QLoRA

Combines LoRA with quantization: the base model weights are stored in 4-bit NormalFloat (NF4) format, while LoRA adapters remain in full precision. This enables fine-tuning a 65B model on a single 48GB GPU.

Other PEFT Methods

  • Prefix Tuning: Prepends learnable "virtual tokens" to the key/value sequences in each attention layer.
  • Adapters: Insert small bottleneck MLP modules between existing layers.
  • IA3: Learns rescaling vectors for keys, values, and FFN activations — even fewer parameters than LoRA.

8. Inference & Decoding

At inference time, autoregressive LLMs generate text one token at a time. Each new token requires a full forward pass (though KV caching avoids recomputing attention for previous tokens). The choice of decoding strategy significantly affects output quality.

Inference Phases

There are two distinct phases:

  1. Prefill (Prompt Processing): The entire input prompt is processed in parallel. This is compute-bound — all tokens are processed simultaneously, fully utilising GPU parallelism. The KV cache is populated for all prompt tokens.
  2. Decode (Token Generation): Tokens are generated one at a time, each requiring a forward pass. This is memory-bandwidth-bound — only one token is processed per step, vastly under-utilising GPU compute but heavily using memory bandwidth to load the KV cache and model weights.
The Prefill–Decode Asymmetry

Prefill is typically 10-100× faster per token than decode. For a 1000-token prompt and 500-token generation, prefill might take 200ms while decode takes 10 seconds. This asymmetry drives many inference optimisation strategies.

KV Cache

Without caching, generating token t would require recomputing attention over all previous tokens 1..t-1 — O(t²) total work for generating t tokens. The KV cache stores the key and value tensors from all previous tokens, so each new token only needs to compute its own Q, K, V and attend to the cached K, V vectors.

# KV Cache memory calculation
cache_size_per_layer = 2 × batch_size × seq_len × n_kv_heads × head_dim × dtype_bytes
total_kv_cache = n_layers × cache_size_per_layer

# Example: LLaMA 70B (GQA with 8 KV heads), batch=1, seq=4096, FP16
= 80 layers × 2 × 1 × 4096 × 8 × 128 × 2 bytes
= 80 × 2 × 4096 × 8 × 128 × 2
≈ 1.34 GB

# With batch=32:
≈ 42.9 GB  (often exceeds model weight memory!)
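
The cache-and-append pattern described above can be sketched for a single attention layer (illustrative only; real implementations preallocate the cache and handle batching, masking, and RoPE):

```python
import torch
import torch.nn.functional as F

def decode_step(q, k_new, v_new, cache):
    """One decode step: append this token's K/V, attend over all cached tokens.
    q, k_new, v_new: (batch, n_heads, 1, head_dim); cache: dict with 'k', 'v'."""
    if cache["k"] is None:
        cache["k"], cache["v"] = k_new, v_new
    else:
        cache["k"] = torch.cat([cache["k"], k_new], dim=2)  # grow along seq dim
        cache["v"] = torch.cat([cache["v"], v_new], dim=2)
    scores = q @ cache["k"].transpose(-2, -1) / (q.size(-1) ** 0.5)
    return F.softmax(scores, dim=-1) @ cache["v"]           # (batch, n_heads, 1, head_dim)

cache = {"k": None, "v": None}
for _ in range(4):  # generate 4 tokens
    q = k = v = torch.randn(1, 2, 1, 8)
    out = decode_step(q, k, v, cache)
```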

KV Cache Compression Techniques

  • GQA/MQA: Reduce number of KV heads (architectural)
  • Quantized KV Cache: Store K, V in INT8 or INT4 instead of FP16
  • Paged Attention (vLLM): Manage KV cache like virtual memory with fixed-size blocks, reducing fragmentation and enabling efficient batching
  • Sliding Window: Only cache the last W tokens (Mistral uses W=4096)
  • Token Eviction: H₂O (Heavy-Hitter Oracle) keeps only the most attended-to tokens in cache

Sampling Strategies

After the forward pass produces logits for the next token, a sampling strategy determines which token to actually select.

Temperature

P(x_i) = exp(z_i/T) / Σ_j exp(z_j/T)

T=1 → unchanged distribution. T>1 → flatter (more random). T<1 → more peaked. T→0 → deterministic (argmax).

Top-K Sampling

Restrict sampling to the K highest-probability tokens, renormalise, and sample. K=1 is greedy decoding.

Top-P (Nucleus) Sampling

Sort tokens by probability descending. Include tokens until their cumulative probability exceeds P. This adaptively adjusts the candidate set size based on the shape of the distribution — more candidates when the model is uncertain, fewer when it's confident.

Min-P Sampling

A newer alternative: set a minimum probability threshold as a fraction of the top token's probability. If the top token has probability 0.9 and min_p=0.1, only tokens with P ≥ 0.09 are kept. This adapts naturally to distribution shape.
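A min-p filter is only a few lines on top of the logits (the `min_p_filter` helper is illustrative):

```python
import torch

def min_p_filter(logits, min_p=0.1):
    """Keep tokens with prob >= min_p * (top token's prob); mask the rest."""
    probs = torch.softmax(logits, dim=-1)
    threshold = min_p * probs.max(dim=-1, keepdim=True).values
    return logits.masked_fill(probs < threshold, float("-inf"))

logits = torch.log(torch.tensor([0.9, 0.08, 0.02]))  # top token has prob 0.9
filtered = min_p_filter(logits, min_p=0.1)           # threshold ≈ 0.1 × 0.9
```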

Repetition Penalty

Multiplicatively reduces the logits of tokens that have already appeared in the context, discouraging repetitive outputs. Typical values: 1.0 (off) to 1.3.

# Combined sampling pipeline
import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=0.7, top_k=50, top_p=0.9):
    # 1. Apply temperature
    logits = logits / temperature

    # 2. Top-K filtering
    if top_k > 0:
        top_k_vals, _ = logits.topk(top_k)
        logits[logits < top_k_vals[..., -1:]] = float('-inf')

    # 3. Top-P (nucleus) filtering
    sorted_logits, sorted_indices = logits.sort(descending=True)
    sorted_probs = sorted_logits.softmax(dim=-1)
    cumulative_probs = sorted_probs.cumsum(dim=-1)
    # Mask tokens once the cumulative probability *before* them exceeds top_p
    mask = cumulative_probs - sorted_probs >= top_p
    sorted_logits[mask] = float('-inf')
    # Un-sort the filtered logits back to vocabulary order
    logits = logits.scatter(-1, sorted_indices, sorted_logits)

    # 4. Sample from the filtered distribution
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

Speculative Decoding

A powerful inference acceleration technique. Uses a small, fast "draft" model to generate K candidate tokens autoregressively, then verifies all K tokens in parallel using the large "target" model. Accepted tokens are kept; the first rejected token is resampled from the target model.

Draft Model (small, fast)          Target Model (large, accurate)
─────────────────────────          ───────────────────────────────
Generate K=5 draft tokens          Verify all 5 in ONE forward pass
"The cat sat on the"               ✓ ✓ ✓ ✗  →  "The cat sat in ..."
                                         ↑
                                   Accept 3, resample 4th

Net: 4 tokens in ~1 large forward pass

Speculative decoding is mathematically lossless — it produces the exact same distribution as the target model. The speedup depends on the acceptance rate, which depends on how well the draft model approximates the target. Typical speedups: 2-3×.
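The rule that makes this lossless: accept draft token x with probability min(1, p_target(x)/p_draft(x)), and on the first rejection resample from the adjusted target distribution. A sketch of just the accept/reject loop (illustrative; the resampling step is omitted):

```python
import torch

def accept_draft_tokens(draft_probs, target_probs, draft_tokens, rng=None):
    """Return the number of accepted draft tokens.
    draft_probs, target_probs: (K, vocab); draft_tokens: (K,)"""
    g = rng or torch.Generator().manual_seed(0)
    for i, tok in enumerate(draft_tokens):
        p_t = target_probs[i, tok]
        p_d = draft_probs[i, tok]
        # Accept with probability min(1, p_target / p_draft)
        if torch.rand(1, generator=g).item() >= (p_t / p_d).clamp(max=1.0):
            return i            # first rejection: resample position i
    return len(draft_tokens)    # all accepted

# If draft and target agree exactly, every token is accepted
probs = torch.full((5, 10), 0.1)
n = accept_draft_tokens(probs, probs, torch.zeros(5, dtype=torch.long))
```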

9. Systems & Optimization

Quantization

Quantization reduces the numerical precision of model weights (and optionally activations) to decrease memory footprint and increase inference throughput. The key insight: LLM weights are typically well-behaved enough that reduced precision introduces minimal quality loss.

Format        | Bits    | Range / Description                          | Memory (70B)
--------------|---------|----------------------------------------------|-------------
FP32          | 32      | Full precision baseline                      | ~280 GB
FP16 / BF16   | 16      | Standard training/inference precision        | ~140 GB
INT8          | 8       | Per-channel / per-token scaling              | ~70 GB
FP8 (E4M3)    | 8       | Hardware-native on H100/MI300X               | ~70 GB
INT4 / NF4    | 4       | NormalFloat for normally distributed weights | ~35 GB
GPTQ / AWQ    | 4       | Calibration-based weight-only quant          | ~35 GB
GGUF Q4_K_M   | 4.8 avg | Mixed precision per-block (llama.cpp)        | ~42 GB
2-bit (QuIP#) | 2       | Aggressive, some quality loss                | ~18 GB

Weight-Only vs. Weight+Activation Quantization

  • Weight-only (W4A16): Weights in 4-bit, activations in FP16. Easy to implement, minimal quality loss. Dequantize weights on-the-fly during matrix multiplication. Used by GPTQ, AWQ, GGUF.
  • Weight+Activation (W8A8): Both weights and activations in INT8. Enables faster integer matrix multiplication on supported hardware. Requires careful handling of activation outliers (SmoothQuant, etc.).
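The core of weight-only quantization can be sketched with symmetric per-channel INT8; the 4-bit schemes build on the same idea with smaller groups and better codebooks (these helpers are illustrative):

```python
import torch

def quantize_int8(w):
    """Symmetric per-output-channel INT8 quantization. w: (out, in)."""
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0  # one scale per output channel
    q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return q, scale

def dequantize(q, scale):
    # On-the-fly dequantization, as in W4A16/W8A16 kernels
    return q.float() * scale

w = torch.randn(64, 128)
q, scale = quantize_int8(w)
w_hat = dequantize(q, scale)
error = (w - w_hat).abs().max()   # bounded by half a quantization step
```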

Parallelism Strategies

Training and serving large models requires distributing computation across multiple GPUs/nodes.

Data Parallelism (DP)

Each GPU holds a full copy of the model. The batch is split across GPUs, each computes gradients on its shard, and gradients are all-reduced. Simple but requires each GPU to hold the full model. ZeRO (Zero Redundancy Optimizer) shards optimizer states, gradients, and optionally parameters across GPUs.

Tensor Parallelism (TP)

Splits individual weight matrices across GPUs. For a matrix multiplication Y = XW, the weight W is column-split across GPUs, each computes a partial result, and results are combined. Requires high-bandwidth interconnects (NVLink). Typical: TP across GPUs within a node.
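Column-split tensor parallelism can be demonstrated with plain NumPy; each slice below stands in for one GPU's shard (the sharding and the final all-gather are simulated in-process):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 8))                    # activations (batch, d_in)
W = rng.normal(size=(8, 6))                    # weight matrix (d_in, d_out)

# Column parallelism: each "GPU" owns a slice of W's output columns
W_shards = np.split(W, 2, axis=1)              # 2-way tensor parallel
partials = [X @ shard for shard in W_shards]   # independent local matmuls
Y = np.concatenate(partials, axis=1)           # all-gather along columns

assert np.allclose(Y, X @ W)                   # identical to the unsharded result
```

Row-split parallelism is the mirror image: W is split along its input dimension and the partial results are summed with an all-reduce instead of concatenated.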

Pipeline Parallelism (PP)

Assigns different layers to different GPUs. GPU 0 processes layers 0-19, GPU 1 processes layers 20-39, etc. The batch is split into micro-batches that flow through the pipeline. Typical: PP across nodes.

Sequence Parallelism (SP)

For very long sequences, splits the sequence dimension across GPUs. Ring Attention distributes the KV cache across GPUs, with each GPU computing attention on its local chunk and passing partial results around a ring topology.

Expert Parallelism (EP)

For MoE models, different experts reside on different GPUs. Tokens are routed to the appropriate GPU via all-to-all communication.

┌─────────────────────────────────────────────────────────────────────┐
│                  TYPICAL PARALLELISM CONFIGURATION                  │
│                                                                     │
│   Node 0 (8× H100 NVLink)             Node 1 (8× H100 NVLink)       │
│   ┌─────────────────────┐             ┌─────────────────────┐       │
│   │  TP across 8 GPUs   │   ──PP──    │  TP across 8 GPUs   │       │
│   │  (Layers 0-39)      │   400Gb     │  (Layers 40-79)     │       │
│   └─────────────────────┘   IB/RoCE   └─────────────────────┘       │
│                                                                     │
│              ↕ DP (ZeRO-1) across node pairs                        │
└─────────────────────────────────────────────────────────────────────┘

FlashAttention

FlashAttention is an IO-aware attention algorithm that fuses the attention computation into a single GPU kernel, avoiding materialisation of the full N×N attention matrix in high-bandwidth memory (HBM). Instead, it tiles the computation and keeps intermediate results in fast SRAM.

Standard Attention Memory Access Pattern

# Standard: multiple HBM reads/writes
S = Q @ K.T           # Write N×N matrix to HBM
P = softmax(S)         # Read N×N from HBM, write N×N back
O = P @ V              # Read N×N from HBM

# Total HBM access: O(N² + Nd) — dominated by N² for long sequences

FlashAttention Memory Access Pattern

# FlashAttention: tiled computation in SRAM
for each tile of Q (block_q):
    for each tile of K, V (block_kv):
        # Load small tiles into SRAM (fast on-chip memory)
        # Compute partial attention in SRAM
        # Update running softmax statistics (online softmax)
        # Accumulate partial output
    # Write final output for this Q tile to HBM

# Total HBM access: O(N²d²/M) where M = SRAM size
# For typical M, this is ~2-4× less than standard attention

FlashAttention achieves 2-4× wall-clock speedup and reduces memory from O(N²) to O(N), enabling much longer contexts without running out of memory. FlashAttention-2 further improved throughput by optimising the parallelism and work partitioning across GPU warps. FlashAttention-3 (H100) exploits hardware features like asynchronous memory copies and FP8 tensor cores.
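The heart of the tiled computation is the online softmax: partial results stay correct because they are rescaled whenever a later block raises the running maximum. A NumPy sketch of that recurrence (function name invented here):

```python
import numpy as np

def online_softmax_weighted_sum(scores, values, block=4):
    """Compute softmax(scores) @ values one block at a time,
    never materialising the full softmax vector."""
    m = -np.inf                              # running max (for stability)
    l = 0.0                                  # running normaliser
    out = np.zeros(values.shape[1])
    for i in range(0, len(scores), block):
        s, v = scores[i:i + block], values[i:i + block]
        m_new = max(m, s.max())
        corr = np.exp(m - m_new)             # rescale previous accumulators
        p = np.exp(s - m_new)
        l = l * corr + p.sum()
        out = out * corr + p @ v
        m = m_new
    return out / l
```

This is the same rescaling FlashAttention's kernel performs per Q tile, with `scores` playing the role of one Q·Kᵀ block and `values` the corresponding V tile.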

10. Context & Memory

Extending the effective context window beyond what was used during training is an active research area. Several approaches exist:

RoPE Frequency Scaling

Since RoPE encodes position as rotation frequencies, the context window can be extended by modifying these frequencies:

  • Position Interpolation (PI): Scale every position index by the ratio of original to target length, m′ = m × (L_orig/L_target), which equivalently scales all RoPE angles by the same factor. Requires a short fine-tune.
  • NTK-aware Scaling: Increase the base of the RoPE frequencies (e.g., from 10000 to 1000000), which spreads the rotations more evenly across the extended range.
  • YaRN (Yet another RoPE extensioN): Combines NTK scaling for high frequencies with PI for low frequencies, plus a temperature scaling factor. Often the best approach.
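The schemes above differ only in how the RoPE angles are computed. A sketch, assuming head dimension 128 and a 4K→16K extension; the NTK base formula below is the commonly used rule of thumb, not taken verbatim from any one source:

```python
import numpy as np

def rope_freqs(dim, base=10000.0):
    """Per-pair RoPE frequencies: theta_j = base^(-2j/dim)."""
    return base ** (-np.arange(0, dim, 2) / dim)

dim, L_orig, L_target = 128, 4096, 16384
positions = np.arange(L_target)

# Position Interpolation: compress positions back into the trained range
pi_angles = np.outer(positions * (L_orig / L_target), rope_freqs(dim))

# NTK-aware scaling: keep positions, raise the base instead
ntk_base = 10000.0 * (L_target / L_orig) ** (dim / (dim - 2))
ntk_angles = np.outer(positions, rope_freqs(dim, base=ntk_base))
```

With PI the largest angle at position L_target matches what the model saw at L_orig during training; NTK leaves the high frequencies (local detail) nearly untouched and stretches mainly the low ones, which is why YaRN blends the two.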

Retrieval-Augmented Generation (RAG)

Instead of stuffing everything into the context, RAG retrieves relevant documents from an external knowledge base and includes them in the prompt. Components:

  1. Document Chunking: Split documents into passages (256-1024 tokens)
  2. Embedding: Encode chunks with an embedding model (e.g., E5, BGE, OpenAI text-embedding-3)
  3. Vector Store: Index embeddings in a vector database (FAISS, Pinecone, Qdrant, Weaviate)
  4. Retrieval: At query time, embed the query, find top-K similar chunks via approximate nearest neighbour search
  5. Generation: Include retrieved chunks in the prompt and generate the answer
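The five steps can be wired together in a few lines. The snippet below fakes the embedding model with a bag-of-words vector so it runs standalone; every name in it is illustrative, and a real system would call an embedding model and a vector database instead:

```python
import numpy as np

chunks = [   # step 1: pre-chunked passages
    "The KV cache stores keys and values from previous tokens.",
    "RoPE encodes position as rotations of query and key vectors.",
    "FlashAttention tiles the attention computation in SRAM.",
]

def tokens(text):
    return text.lower().replace(".", "").replace("?", "").split()

vocab = sorted({w for c in chunks for w in tokens(c)})

def embed(text):  # step 2: toy stand-in for E5 / BGE / text-embedding-3
    v = np.array([tokens(text).count(w) for w in vocab], dtype=float)
    return v / (np.linalg.norm(v) + 1e-9)

index = np.stack([embed(c) for c in chunks])   # step 3: the "vector store"

def retrieve(query, k=2):                      # step 4: top-K by cosine sim
    sims = index @ embed(query)                # unit vectors → dot = cosine
    return [chunks[i] for i in np.argsort(-sims)[:k]]

# step 5: build the augmented prompt
context = "\n".join(retrieve("how does the KV cache work?"))
prompt = f"Context:\n{context}\n\nQuestion: how does the KV cache work?"
```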

Sliding Window Attention

Each token only attends to the previous W tokens (e.g., W=4096). Information propagates across layers — after L layers, the effective receptive field is L × W tokens. Used in Mistral models. Much more memory-efficient for long sequences.
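The mask itself is simple to construct; a small sketch with an invented helper name:

```python
import numpy as np

def sliding_window_mask(T, W):
    """Boolean mask where token i may attend only to positions (i-W, i]."""
    i = np.arange(T)[:, None]
    j = np.arange(T)[None, :]
    return (j <= i) & (j > i - W)

mask = sliding_window_mask(8, 3)
assert mask[5].sum() == 3       # token 5 sees positions 3, 4, 5
assert not mask[5, 6]           # causal: never attends to the future
```

Stacking L such layers lets position i draw on information from up to L × W tokens back, even though no single layer looks further than W.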

Ring Attention

Distributes the KV cache across multiple GPUs arranged in a ring. Each GPU holds a chunk of the sequence and computes attention on its local portion, while asynchronously sending/receiving KV blocks to/from neighbouring GPUs. This enables context lengths that scale linearly with the number of GPUs.

11. Safety & Alignment

Ensuring LLMs are safe, helpful, and honest is a multi-layered challenge spanning training, deployment, and ongoing monitoring.

Constitutional AI (CAI)

Developed by Anthropic, CAI uses a set of principles (a "constitution") to guide the model's behaviour. The process involves two phases: (1) self-critique, where the model generates and revises responses according to the principles, and (2) RL from AI Feedback (RLAIF), where the model itself evaluates response pairs according to the constitution, replacing some human labelling.

Safety Layers

  • Pre-training data filtering: Remove toxic, harmful, and private content from training data
  • SFT on safe demonstrations: Train the model on examples that demonstrate refusal of harmful requests
  • RLHF/DPO with safety preferences: Teach the model to prefer safe responses
  • Red-teaming: Adversarial testing to find failure modes
  • Input/output classifiers: Separate models that flag harmful inputs/outputs
  • System prompts: Instructions that establish behavioral boundaries

Known Challenges

  • Jailbreaking: Techniques that circumvent safety training (prompt injection, role-playing scenarios, encoded instructions)
  • Hallucination: Confident generation of false information — fundamentally rooted in the next-token prediction objective
  • Sycophancy: Agreeing with the user even when they are wrong, an artifact of RLHF optimising for human approval
  • Deceptive alignment: Theoretical concern where a model appears aligned during evaluation but behaves differently in deployment

12. Evaluation & Benchmarks

LLM evaluation is notoriously difficult. Common benchmarks and their focus areas:

| Benchmark | Focus | Format | Notes |
|---|---|---|---|
| MMLU | World knowledge | 57 subjects, multiple choice | Knowledge breadth |
| MMLU-Pro | Harder knowledge | 10-option MC, harder questions | Less saturated |
| GSM8K | Math reasoning | Grade-school word problems | Tests multi-step arithmetic |
| MATH | Advanced math | Competition-level problems | LaTeX-formatted |
| HumanEval | Code generation | Python function completion | 164 problems |
| SWE-Bench | Real-world coding | GitHub issue resolution | Full repo context |
| ARC-Challenge | Science reasoning | Grade-school science MC | Requires reasoning |
| HellaSwag | Commonsense | Sentence completion | Adversarially filtered |
| WinoGrande | Commonsense | Pronoun resolution | Binary choice |
| TruthfulQA | Factuality | Questions with common misconceptions | Tests resistance to popular myths |
| MT-Bench | Chat quality | Multi-turn conversations | LLM-as-judge scoring |
| Chatbot Arena | Overall quality | Blind pairwise human preferences | Elo rating system |
| GPQA | Expert knowledge | Graduate-level questions | PhD-level difficulty |

Evaluation Methods

  • Perplexity: exp(average negative log-likelihood). Measures how well the model predicts held-out text. Lower is better. Not directly indicative of downstream task performance.
  • N-shot prompting: Evaluate by providing N examples in the prompt and testing on new examples. 0-shot tests raw capability; few-shot tests in-context learning.
  • LLM-as-Judge: Use a strong LLM (e.g., GPT-4) to evaluate another model's outputs. Correlates well with human judgments but inherits judge model biases.
  • Human evaluation: Gold standard but expensive, slow, and hard to reproduce. Chatbot Arena's crowdsourced approach scales better than controlled studies.
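As a concrete anchor for the first bullet, perplexity from per-token probabilities (toy values; this assumes the probability the model assigned to each actual next token has already been extracted):

```python
import numpy as np

def perplexity(token_probs):
    """exp(mean negative log-likelihood) over the held-out tokens."""
    return float(np.exp(-np.mean(np.log(token_probs))))

# Uniform guessing over 4 choices gives perplexity 4 ...
assert np.isclose(perplexity([0.25, 0.25, 0.25]), 4.0)
# ... and sharper correct predictions drive it toward 1
assert perplexity([0.9, 0.8, 0.9]) < perplexity([0.5, 0.5, 0.5])
```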

13. Frontier Research Directions

Chain-of-Thought & Reasoning

Models like o1 and DeepSeek-R1 are trained to generate explicit reasoning chains before answering. This is achieved through RL on problems with verifiable answers (math, code) where the model learns to produce longer, more thorough reasoning traces that lead to correct solutions. The "thinking" tokens are generated at inference time, trading compute for accuracy.

State-Space Models & Alternatives

Mamba and similar architectures use structured state-space models (SSMs) instead of attention, achieving linear scaling with sequence length (O(n) vs O(n²)). These models use a selection mechanism (data-dependent parameters) that provides context-aware filtering. Hybrid architectures (Jamba) combine SSM layers with attention layers.
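The O(n) claim comes from the recurrent form of the SSM: each step folds the input into a fixed-size state. A minimal diagonal-SSM scan (non-selective, with invented parameter values) illustrates the shape of the computation:

```python
import numpy as np

def ssm_scan(x, a, b, c):
    """Diagonal linear state-space recurrence, O(T) in sequence length:
        h_t = a * h_{t-1} + b * x_t   (elementwise state update)
        y_t = c . h_t                 (readout)
    Mamba's selection mechanism additionally makes a, b, c depend on x_t."""
    h = np.zeros_like(a)
    ys = []
    for x_t in x:
        h = a * h + b * x_t
        ys.append(c @ h)
    return np.array(ys)

# An impulse decays geometrically through the state — one pass, no n×n matrix
y = ssm_scan(np.array([1.0, 0.0, 0.0]), np.full(4, 0.9), np.ones(4), np.ones(4) / 4)
assert np.allclose(y, [1.0, 0.9, 0.81])
```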

Multimodal Integration

Modern LLMs increasingly handle multiple modalities — text, images, audio, and video. Vision-language models (GPT-4V, Gemini, Claude) typically use a vision encoder (ViT) to convert image patches into embedding vectors that are interleaved with text token embeddings in the Transformer's input sequence.

Mechanistic Interpretability

Understanding what individual neurons, attention heads, and circuits compute inside LLMs. Key findings include induction heads (copying patterns), indirect object identification circuits, and factual recall mechanisms. Sparse autoencoders (SAEs) are used to decompose activations into interpretable features.

Test-Time Compute Scaling

Rather than only scaling training compute (bigger models, more data), recent work shows that scaling compute at inference time — via longer reasoning traces, search (beam search, tree search), self-verification, and ensembling — can dramatically improve performance, especially on hard reasoning problems.
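The simplest ensembling variant, self-consistency, majority-votes over the final answers extracted from independently sampled reasoning traces (sketch; answer extraction is assumed already done):

```python
from collections import Counter

def self_consistency(answers):
    """Majority vote over the final answers of sampled reasoning traces."""
    return Counter(answers).most_common(1)[0][0]

# Five sampled traces, three of which agree on the answer
assert self_consistency(["42", "17", "42", "42", "19"]) == "42"
```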

Distillation

Training smaller "student" models to mimic larger "teacher" models. This can produce compact models with performance disproportionate to their size. Techniques include logit matching (KL divergence between student and teacher output distributions), feature matching (aligning intermediate representations), and on-policy distillation (generating data from the teacher).
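Logit matching is typically implemented as a KL divergence between temperature-softened distributions; a NumPy sketch (the T² factor is the standard scaling that keeps gradient magnitudes comparable across temperatures):

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def distill_loss(student_logits, teacher_logits, T=2.0):
    """KL(teacher || student) on temperature-T softened distributions."""
    log_p = log_softmax(teacher_logits / T)   # teacher soft targets
    log_q = log_softmax(student_logits / T)   # student predictions
    p = np.exp(log_p)
    return float((p * (log_p - log_q)).sum(axis=-1).mean() * T * T)

teacher = np.array([[2.0, 1.0, 0.1]])
assert np.isclose(distill_loss(teacher, teacher), 0.0)  # perfect match → 0
```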

14. Glossary

| Term | Definition |
|---|---|
| Autoregressive | Generating output one token at a time, conditioning on all previous tokens |
| BF16 (bfloat16) | 16-bit floating point with the 8-bit exponent range of FP32 but reduced mantissa precision |
| Causal mask | Upper-triangular mask that prevents tokens from attending to future positions |
| Cross-attention | Attention where Q comes from the decoder and K, V come from the encoder |
| dmodel | The hidden dimension / width of the residual stream |
| Embedding | Dense vector representation of a discrete token |
| FLOP | Floating-point operation; FLOPs measure total compute, FLOP/s measures throughput |
| Gradient checkpointing | Trading compute for memory by recomputing activations during the backward pass instead of storing them |
| Hallucination | Model generating confident but factually incorrect information |
| In-context learning | Model learning to perform a task from examples provided in the prompt, without weight updates |
| KV cache | Stored key and value tensors from previous tokens to avoid recomputation during autoregressive generation |
| Logits | Raw (unnormalised) output scores from the final linear layer before softmax |
| MFU | Model FLOPs Utilisation — fraction of peak hardware FLOP/s actually achieved |
| Perplexity | Exponentiated average negative log-likelihood; measures how "surprised" the model is by held-out text |
| Residual stream | The hidden state vector that flows through the model, with each layer adding its contribution |
| Softmax | Function that converts logits to a probability distribution: softmax(z_i) = e^{z_i} / Σ_j e^{z_j} |
| Token | The basic unit of text processing — a subword, word, or character produced by the tokenizer |
| Transformer | Neural network architecture based on self-attention, introduced in "Attention Is All You Need" (2017) |