Model

Architecture

| Component | Value |
| --- | --- |
| Embedding dim | 16 |
| Attention heads | 4 (head_dim=4) |
| Layers | 1 |
| Context length | 16 |
| Vocab | 27 (a-z + BOS) |
| Parameters | 4,192 |

Parameter breakdown

| Weight | Shape | Count |
| --- | --- | --- |
| wte (token embeddings) | [27, 16] | 432 |
| wpe (position embeddings) | [16, 16] | 256 |
| wq (4 heads) | 4 x [16, 4] | 256 |
| wk (4 heads) | 4 x [16, 4] | 256 |
| wv (4 heads) | 4 x [16, 4] | 256 |
| wo (4 heads) | 4 x [4, 16] | 256 |
| w_fc1 (MLP up) | [16, 64] | 1,024 |
| w_fc2 (MLP down) | [64, 16] | 1,024 |
| w_lm (LM head) | [16, 27] | 432 |
| Total | | 4,192 |
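
The total can be checked by multiplying out the shapes in the table. The sketch below does that arithmetic in Rust. The function name `param_count` and its argument names are illustrative and do not come from the codebase, and it assumes there are no bias terms, since the table lists none.

```rust
/// Count scalar parameters from the shapes listed in the breakdown table.
/// Hypothetical helper: names are illustrative, and no bias terms are assumed.
fn param_count(
    vocab: usize,
    ctx: usize,
    dim: usize,
    heads: usize,
    head_dim: usize,
    hidden: usize,
) -> usize {
    let wte = vocab * dim; // token embeddings, [27, 16]
    let wpe = ctx * dim; // position embeddings, [16, 16]
    let qkv = 3 * heads * dim * head_dim; // wq, wk, wv: 4 x [16, 4] each
    let wo = heads * head_dim * dim; // output projection, 4 x [4, 16]
    let mlp = dim * hidden + hidden * dim; // w_fc1 [16, 64] + w_fc2 [64, 16]
    let w_lm = dim * vocab; // LM head, [16, 27]
    wte + wpe + qkv + wo + mlp + w_lm
}
```

Calling `param_count(27, 16, 16, 4, 4, 64)` reproduces the 4,192 total, with the MLP accounting for 2,048 of it.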

Forward pass

tokens → one_hot → matmul(wte) + matmul(wpe) → RMSNorm
       → Attention(Q,K,V,O) + residual
       → RMSNorm → MLP(fc1 → ReLU → fc2) + residual
       → matmul(w_lm) → logits

Embedding lookup

Embeddings use a one-hot matmul instead of a dedicated Embedding layer. This ensures gradients flow through the standard matmul backward pass:

tok_emb = one_hot(tokens, 27) @ wte   // [n, 16]
pos_emb = one_hot(0..n, 16)  @ wpe   // [n, 16]
x = tok_emb + pos_emb
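
To make the equivalence concrete, here is a minimal sketch of the one-hot lookup on nested `Vec<f32>` matrices. The helpers `one_hot`, `matmul`, and `embed` are written for illustration only. The project's own tensor type and function signatures are not shown in this section, so treat these as assumptions.

```rust
/// Build an [n, classes] one-hot matrix from integer indices.
fn one_hot(indices: &[usize], classes: usize) -> Vec<Vec<f32>> {
    indices
        .iter()
        .map(|&i| {
            let mut row = vec![0.0_f32; classes];
            row[i] = 1.0; // panics if i >= classes
            row
        })
        .collect()
}

/// Plain [n, k] x [k, m] matrix product.
fn matmul(a: &[Vec<f32>], b: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let m = b[0].len();
    a.iter()
        .map(|row| {
            (0..m)
                .map(|j| row.iter().zip(b).map(|(x, b_row)| x * b_row[j]).sum::<f32>())
                .collect::<Vec<f32>>()
        })
        .collect()
}

/// x = one_hot(tokens) @ wte + one_hot(positions) @ wpe
fn embed(tokens: &[usize], wte: &[Vec<f32>], wpe: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let positions: Vec<usize> = (0..tokens.len()).collect();
    let tok_emb = matmul(&one_hot(tokens, wte.len()), wte);
    let pos_emb = matmul(&one_hot(&positions, wpe.len()), wpe);
    tok_emb
        .iter()
        .zip(&pos_emb)
        .map(|(t, p)| t.iter().zip(p).map(|(a, b)| a + b).collect::<Vec<f32>>())
        .collect()
}
```

Each one-hot row selects exactly one row of the weight matrix, so the product returns the same values as a direct index into `wte` or `wpe`, while the backward pass is the ordinary matmul gradient.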

RMSNorm

Applied per-row with straight-through gradient estimation. Here x_i is row i of the input, and the mean is taken over that row's elements:

RMSNorm(x)_i = x_i / sqrt(mean(x_i^2) + 1e-5)
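
The forward computation in the formula can be sketched as follows. This is a hedged illustration on nested `Vec<f32>` rows, not the project's implementation: the name `rms_norm` is an assumption, no learned gain is applied because none appears in the parameter table, and the straight-through backward pass is not shown.

```rust
/// Divide each row by sqrt(mean(row^2) + 1e-5), as in the formula above.
/// Forward pass only; hypothetical helper for illustration.
fn rms_norm(x: &[Vec<f32>]) -> Vec<Vec<f32>> {
    const EPS: f32 = 1e-5;
    x.iter()
        .map(|row| {
            // Mean of squared elements over this row.
            let mean_sq = row.iter().map(|v| v * v).sum::<f32>() / row.len() as f32;
            let scale = 1.0 / (mean_sq + EPS).sqrt();
            row.iter().map(|v| v * scale).collect::<Vec<f32>>()
        })
        .collect()
}
```

After normalization, each nonzero row has a mean square close to 1. The epsilon keeps an all-zero row finite instead of dividing by zero.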

In attention, the causal mask sets masked positions to -1e9 instead of -inf, so the scores still satisfy the upstream softmax precondition contract (x.iter().all(|v| v.is_finite())).
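
A small sketch shows why the finite fill value is enough. The names `MASK_VALUE`, `apply_causal_mask`, and `softmax` below are illustrative assumptions. Only the precondition check itself is taken from the text; the upstream softmax is not reproduced here.

```rust
/// Finite stand-in for -inf, as described above.
const MASK_VALUE: f32 = -1e9;

/// Overwrite scores for future positions (column j > row i) with MASK_VALUE.
fn apply_causal_mask(scores: &mut [Vec<f32>]) {
    for (i, row) in scores.iter_mut().enumerate() {
        for (j, s) in row.iter_mut().enumerate() {
            if j > i {
                *s = MASK_VALUE;
            }
        }
    }
}

/// Max-subtracted softmax that enforces the finite-input precondition.
fn softmax(x: &[f32]) -> Vec<f32> {
    assert!(x.iter().all(|v| v.is_finite()), "softmax input must be finite");
    let max = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = x.iter().map(|v| (v - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```

In f32, `exp` of a value near -1e9 underflows to zero, so masked positions receive zero attention weight, the same result -inf would give, without tripping the `is_finite` check.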