GPT Adventures — Part 1: Baseline Implementation
Philosophy
The temptation, when starting a project like this, is to immediately reach for the shiny things. Flash attention, torch.compile, mixed precision, distributed training — these are all coming, but not yet. Grug — patron saint of the pragmatic developer — has a lot to say about the complexity demon and how it creeps in uninvited. The complexity demon is patient. It will wait.
The goal for Part 1 is a working GPT-2 Small: correct, readable, and explicit. Every matrix multiplication named. Every shape annotated. No clever abstractions obscuring what is actually happening. The implementation should be something a future, wiser version of me can look at whilst holding a profiler trace and immediately understand what is slow and why — without first having to reverse-engineer what the code is doing at all.
This means some deliberate simplifications relative to what a production implementation would look like. No weight tying between the embedding and unembedding matrices. ReLU rather than GELU. No KV cache. These are not mistakes to be fixed — they are placeholders to be swapped in one at a time, with a clear before-and-after measurement each time.
Model architecture
The top-level Transformer class takes six parameters: d_attention (the residual stream dimension), d_qkv (the per-head query/key/value dimension), d_ff (the feed-forward hidden dimension), vocab_size (the number of tokens in the vocabulary), n_layers, and max_ctx (the maximum sequence length). The number of attention heads is derived rather than specified directly: n_heads = d_attention // d_qkv, with an assert enforcing divisibility.

    class Transformer(nn.Module):
        def __init__(self, d_attention, d_qkv, d_ff, vocab_size, n_layers, max_ctx):
            ...
            assert self.d_attention % self.d_qkv == 0
            self.n_heads = self.d_attention // self.d_qkv

For GPT-2 Small, this gives us the familiar configuration: d_attention=768, d_qkv=64, d_ff=3072, 12 heads, 12 layers, max_ctx=1024.
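As a quick sanity check of the derived head count, the sketch below collects the GPT-2 Small values in a small dataclass. `GPTConfig` is a hypothetical name introduced only for illustration; it is not part of the implementation described here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GPTConfig:
    """Hypothetical container for the constructor arguments described above."""

    d_attention: int = 768
    d_qkv: int = 64
    d_ff: int = 3072
    vocab_size: int = 50257  # size of the GPT-2 BPE vocabulary
    n_layers: int = 12
    max_ctx: int = 1024

    @property
    def n_heads(self) -> int:
        # Same divisibility check as the assert in Transformer.__init__
        assert self.d_attention % self.d_qkv == 0, "d_attention must be a multiple of d_qkv"
        return self.d_attention // self.d_qkv


config = GPTConfig()
print(config.n_heads)  # 12
```

Deriving n_heads instead of passing it in means the three numbers can never disagree: an inconsistent pair of d_attention and d_qkv fails at construction time rather than deep inside a reshape.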
The forward pass is deliberately sequential and annotated:

    def forward(self, x):
        bs, seq_len = x.shape
        x = self.embeds(x)  # bs, seq_len, d_attention
        pos_embed = self.pos_embeds(self.pos_embed_range[:seq_len])
        x = x + pos_embed  # bs, seq_len, d_attention
        for transformer_block in self.blocks:
            x = transformer_block(x, self.causal_mask, self.is_inference)
        logits = self.vocab_project(x)  # bs, seq_len, vocab_size
        return logits

The positional embeddings are learned rather than sinusoidal: simpler to implement and standard for GPT-2. The causal mask is computed once at construction and registered as a buffer, so it moves to the correct device automatically when the model is moved.
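The buffer mechanism can be sketched in isolation. The module below is a hypothetical helper (the class name `PositionalInputs` and the `persistent=False` choice are assumptions, not taken from the original code); it only shows how a position range and a causal mask can be registered once and then follow the module across devices.

```python
import torch
import torch.nn as nn


class PositionalInputs(nn.Module):
    """Hypothetical helper: learned positions plus a precomputed causal mask."""

    def __init__(self, d_attention: int, max_ctx: int):
        super().__init__()
        self.pos_embeds = nn.Embedding(max_ctx, d_attention)
        # Buffers are not parameters, but they follow the module through .to(device).
        self.register_buffer("pos_embed_range", torch.arange(max_ctx), persistent=False)
        # -inf strictly above the diagonal, 0 on and below it
        mask = torch.triu(torch.full((max_ctx, max_ctx), float("-inf")), diagonal=1)
        self.register_buffer("causal_mask", mask, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[1]
        pos_embed = self.pos_embeds(self.pos_embed_range[:seq_len])
        return x + pos_embed  # bs, seq_len, d_attention


module = PositionalInputs(d_attention=8, max_ctx=4)
out = module(torch.zeros(2, 3, 8))
print(out.shape)
```

Because the buffers are registered on the module, a single `module.to(device)` call moves the embedding weights, the position range, and the mask together.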
Attention
The attention block has three separate projection matrices — Wq, Wk, Wv — one for each of queries, keys, and values. This is the most explicit possible implementation of multi-head attention: each projection is named, and their outputs are clearly distinct objects before the dot products are computed.

    self.Wq = nn.Linear(d_attention, n_heads * d_qkv, bias=False)
    self.Wk = nn.Linear(d_attention, n_heads * d_qkv, bias=False)
    self.Wv = nn.Linear(d_attention, n_heads * d_qkv, bias=False)

This is not how most production implementations do it: fusing Q, K, V into a single matrix multiplication is faster and the standard approach. But fusing them is an optimisation, and optimisations come later. For now, three matrices is clearer about what is actually happening.
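For reference, the fused variant computes exactly the same projections. The sketch below is an illustration with arbitrary small sizes, and `Wqkv` is a hypothetical name: it stacks the three weight matrices into one linear layer and checks that splitting its output recovers the separate results.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_attention, n_heads, d_qkv = 16, 4, 4

Wq = nn.Linear(d_attention, n_heads * d_qkv, bias=False)
Wk = nn.Linear(d_attention, n_heads * d_qkv, bias=False)
Wv = nn.Linear(d_attention, n_heads * d_qkv, bias=False)

# Hypothetical fused layer: stack the three weight matrices along the output axis.
Wqkv = nn.Linear(d_attention, 3 * n_heads * d_qkv, bias=False)
with torch.no_grad():
    Wqkv.weight.copy_(torch.cat([Wq.weight, Wk.weight, Wv.weight], dim=0))

x = torch.randn(2, 5, d_attention)  # bs, seq_len, d_attention
q_cat, k_cat, v_cat = Wq(x), Wk(x), Wv(x)

# One matmul, then split the last axis back into the three projections.
q_fused, k_fused, v_fused = Wqkv(x).chunk(3, dim=-1)

print(torch.allclose(q_cat, q_fused, atol=1e-5))
```

The maths is identical; the fused form simply launches one larger matrix multiplication instead of three smaller ones, which is why it can wait until there is a profiler trace to justify it.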
The pre-norm architecture is used throughout: LayerNorm is applied to the input before the attention or feed-forward computation, not after. The residual connection then adds the un-normalised input back to the output. This matches GPT-2, which moved LayerNorm to the input of each sub-block; the original Transformer and GPT-1 used post-norm. Pre-norm has since become the standard in most subsequent architectures owing to better training stability at scale, and since the two are identical in implementation complexity, there is no reason not to use the better one.
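The difference between the two orderings fits in two small functions. This is a sketch with hypothetical helper names (`pre_norm_residual`, `post_norm_residual`) and a plain linear layer standing in for the attention or feed-forward computation; it is not the code from the implementation itself.

```python
import torch
import torch.nn as nn


def pre_norm_residual(x, sublayer, layer_norm):
    # Normalise the input, run the sublayer, add the un-normalised input back.
    return x + sublayer(layer_norm(x))


def post_norm_residual(x, sublayer, layer_norm):
    # Run the sublayer on the raw input, add the residual, then normalise the sum.
    return layer_norm(x + sublayer(x))


torch.manual_seed(0)
d_attention = 8
layer_norm = nn.LayerNorm(d_attention)
sublayer = nn.Linear(d_attention, d_attention)  # stand-in for attention or the MLP

x = torch.randn(2, 3, d_attention)  # bs, seq_len, d_attention
pre = pre_norm_residual(x, sublayer, layer_norm)
post = post_norm_residual(x, sublayer, layer_norm)
print(pre.shape, post.shape)
```

In the pre-norm form the residual stream itself is never normalised, so there is an unobstructed identity path from input to output through every block.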
The scaled dot-product attention is written out explicitly using einops.einsum, with dimension names spelled out in full:

    atten_dot_prods = einops.einsum(
        einops.rearrange(q_cat, 'bs seq_len (n_heads d_qkv) -> bs seq_len n_heads d_qkv', n_heads=self.n_heads),
        einops.rearrange(k_cat, 'bs seq_len (n_heads d_qkv) -> bs seq_len n_heads d_qkv', n_heads=self.n_heads),
        'bs seq_q n_heads d_qkv, bs seq_k n_heads d_qkv -> bs seq_q seq_k n_heads'
    ) / np.sqrt(self.d_qkv)

The einops dimension strings serve as inline documentation: it is immediately clear that the output has shape (batch, seq_q, seq_k, n_heads). There is no ambiguity about which axis is which, and when something goes wrong with shapes, the rearrange strings tell you exactly where.
The causal mask is a square matrix holding -inf strictly above the diagonal and zero elsewhere, and it is added to the attention logits before softmax. Positions that should not be attended to therefore become -inf and collapse to zero attention weight. In this initial version it is constructed and moved to the correct device inside the forward pass, a redundant allocation on every call, which the profiler will duly notice.
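The scaling, mask, and softmax can be put together in a few lines. The sketch below uses plain `torch.einsum` instead of einops to keep it dependency-free, with arbitrary small sizes and random tensors standing in for the projected queries, keys, and values. It keeps the same (bs, seq_q, seq_k, n_heads) layout, so the softmax runs over the key axis.

```python
import math

import torch

torch.manual_seed(0)
bs, seq_len, n_heads, d_qkv = 2, 4, 3, 8
q = torch.randn(bs, seq_len, n_heads, d_qkv)
k = torch.randn(bs, seq_len, n_heads, d_qkv)
v = torch.randn(bs, seq_len, n_heads, d_qkv)

# bs, seq_q, seq_k, n_heads
atten_dot_prods = torch.einsum("bqhd,bkhd->bqkh", q, k) / math.sqrt(d_qkv)

# -inf where the key position is later than the query position, 0 elsewhere
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

# Broadcast the mask over batch and heads, then normalise over the key axis (dim=2).
atten_weights = torch.softmax(atten_dot_prods + causal_mask[None, :, :, None], dim=2)

# Weighted sum of values: bs, seq_q, n_heads, d_qkv
atten_out = torch.einsum("bqkh,bkhd->bqhd", atten_weights, v)
print(atten_out.shape)
```

A useful check when debugging is that the first query position can only attend to itself, so its output must equal the first value vector exactly.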
Feed-forward block
The MLP is the canonical two-layer design: up-projection to d_ff (4× the residual stream dimension), ReLU activation, down-projection back to d_attention. Again, pre-norm and residual connection:

    def forward(self, x):
        x_res = x
        x = self.layer_norm(x)
        x = self.Wup(x)  # bs, seq_len, d_ff
        x = nn.functional.relu(x)
        x = self.Wdp(x)  # bs, seq_len, d_attention
        return x + x_res

GELU would be more faithful to the original GPT-2 paper, but ReLU is simpler and easier to reason about. Swapping it in later will be a one-line change with a measurable effect.
Data pipeline
Tokenisation uses tiktoken with GPT-2’s BPE vocabulary (~50k tokens). Each line of the corpus is encoded with <|endoftext|> delimiters, producing a flat token stream that is then chunked into fixed-length sequences of max_ctx tokens:

    n_samples = len(data) // self.max_seq_len
    self.data_samples = data[:n_samples * self.max_seq_len].view(n_samples, self.max_seq_len)

Targets are the inputs offset by one position, i.e. next-token prediction. The tokenised corpus is cached to disk after the first run to avoid re-tokenising on every launch.
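The chunking and the one-position offset can be seen on a toy token stream. How the original dataset slices inputs and targets is not shown here, so the version below is an assumption: both are taken from the same chunk, one position apart, which leaves each training pair one token shorter than the chunk.

```python
import torch

max_seq_len = 4
data = torch.arange(10)  # stand-in for the flat token stream

# Drop the ragged tail, then reshape into fixed-length rows.
n_samples = len(data) // max_seq_len
data_samples = data[: n_samples * max_seq_len].view(n_samples, max_seq_len)

# Assumption: inputs and targets come from the same chunk, shifted by one token.
inputs = data_samples[:, :-1]
targets = data_samples[:, 1:]

print(inputs.tolist())   # [[0, 1, 2], [4, 5, 6]]
print(targets.tolist())  # [[1, 2, 3], [5, 6, 7]]
```

Note that tokens 8 and 9 are discarded because they do not fill a complete row; with a real corpus the loss is negligible.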
The dataset also supports splitting off a validation set via split_valid_from_train, which uses a random permutation to pull out a fraction of samples, removing them from the training set in the process.
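A permutation-based split might look roughly like the following. The function name `split_valid`, its signature, and the seeding are assumptions made for illustration; unlike the method described above, this version returns both halves instead of mutating a dataset in place.

```python
import torch


def split_valid(data_samples: torch.Tensor, valid_fraction: float, seed: int = 0):
    """Hypothetical stand-in for split_valid_from_train; returns (train, valid)."""
    generator = torch.Generator().manual_seed(seed)
    # Shuffle the row indices once, then cut the shuffled list in two.
    perm = torch.randperm(data_samples.shape[0], generator=generator)
    n_valid = int(data_samples.shape[0] * valid_fraction)
    valid_idx, train_idx = perm[:n_valid], perm[n_valid:]
    return data_samples[train_idx], data_samples[valid_idx]


samples = torch.arange(20).view(10, 2)  # 10 toy samples of length 2
train, valid = split_valid(samples, valid_fraction=0.2)
print(train.shape, valid.shape)
```

Cutting a single permutation guarantees the two sets are disjoint and together cover every sample, which is the property that matters for an honest validation loss.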
Training loop
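The training loop is straightforward: forward pass, cross-entropy loss, backward pass, optimiser step. The optimiser is AdamW with lr=6e-4, β=(0.9, 0.95), weight_decay=0.01. The GPT-2 paper does not publish its optimiser settings; the learning rate and betas match those reported for the 125M-parameter model in the GPT-3 paper.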
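One step of that loop can be sketched as follows. The model here is a toy stand-in (an embedding followed by a linear layer) rather than the Transformer, and `train_step` is a hypothetical helper name; only the optimiser settings are taken from the text.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, d_model = 11, 8

# Toy stand-in: anything mapping token ids to (bs, seq_len, vocab_size) logits.
model = nn.Sequential(nn.Embedding(vocab_size, d_model), nn.Linear(d_model, vocab_size))
optimiser = torch.optim.AdamW(
    model.parameters(), lr=6e-4, betas=(0.9, 0.95), weight_decay=0.01
)

inputs = torch.randint(0, vocab_size, (4, 6))   # bs, seq_len
targets = torch.randint(0, vocab_size, (4, 6))  # bs, seq_len


def train_step(inputs: torch.Tensor, targets: torch.Tensor) -> float:
    logits = model(inputs)  # bs, seq_len, vocab_size
    # cross_entropy expects (N, C) logits and (N,) class ids: flatten batch and sequence.
    loss = nn.functional.cross_entropy(
        logits.reshape(-1, vocab_size), targets.reshape(-1)
    )
    optimiser.zero_grad(set_to_none=True)
    loss.backward()
    optimiser.step()
    return loss.item()


weights_before = model[1].weight.detach().clone()
loss_value = train_step(inputs, targets)
print(loss_value)
```

Flattening the batch and sequence axes treats every position as an independent classification example, which is exactly what next-token prediction is.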
The learning rate schedule follows a cosine decay with linear warmup, implemented from scratch as a custom LRScheduler. The warmup ramps linearly from lr_min to lr_max over the first wu_fraction of total steps, then follows a quarter-period cosine (the argument runs from 0 to π/2) back down to lr_min:

    if self.current_step <= self.wu_steps:
        return [(self.current_step / self.wu_steps) * adjusted_scale + self.lr_min]
    else:
        adjusted_cosine_step = (self.current_step - self.wu_steps) / (self.total_steps - self.wu_steps)
        return [math.cos(adjusted_cosine_step * math.pi / 2) * adjusted_scale + self.lr_min]

Implementing the scheduler from scratch rather than using a library implementation was a deliberate choice: understanding what the learning rate is actually doing at each step is useful context for later, when we are trying to understand why a training run behaves the way it does.
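To inspect the schedule without a model or optimiser, the same arithmetic can be written as a stateless function. This is a sketch: the function name `lr_at_step` is hypothetical, and `adjusted_scale` is assumed to be lr_max - lr_min, which the snippet above does not show but which is consistent with a ramp from lr_min to lr_max.

```python
import math


def lr_at_step(step, total_steps, wu_fraction, lr_min, lr_max):
    """Hypothetical stateless version of the schedule in the snippet above."""
    wu_steps = max(1, int(total_steps * wu_fraction))  # guard against division by zero
    adjusted_scale = lr_max - lr_min  # assumed definition of adjusted_scale
    if step <= wu_steps:
        # Linear warmup from lr_min to lr_max
        return (step / wu_steps) * adjusted_scale + lr_min
    # Fraction of the decay phase completed, from 0 to 1
    adjusted_cosine_step = (step - wu_steps) / (total_steps - wu_steps)
    return math.cos(adjusted_cosine_step * math.pi / 2) * adjusted_scale + lr_min


schedule = [lr_at_step(s, 1000, 0.1, 6e-5, 6e-4) for s in range(1001)]
print(schedule[0], max(schedule), schedule[-1])
```

Plotting `schedule` before a long run is a cheap way to confirm that the peak lands where expected and that the final learning rate returns to lr_min.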
What is deliberately left out
To keep things explicit and measurable, several things are intentionally absent from this implementation:
- No torch.compile: compilation obscures what kernels are being launched, which makes profiling harder
- No mixed precision: fp32 throughout, so numerical behaviour is straightforward
- No flash attention: the naive O(N²) attention is written out in full; the inefficiency is the point
- No KV cache: inference re-computes all positions from scratch every time; the scope of this project is training optimisation, so inference efficiency is deliberately set aside to keep things focused
- Single GPU only: distribution comes in Part 4
Each of these will be added in a later part, with a clear measurement of its effect. For now, the complexity demon remains outside the codebase. But it will come — Grug know it always does.
