Created
February 25, 2026 02:36
-
-
Save cosminscn/e84e98a6d247299acea065167ece2753 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """58-param nanoGPT that adds any two 10-digit numbers. No training. | |
| Down from 130 params. HONEST counting: every stored numerical value is an | |
| nn.Parameter. Zero buffers. The only "free" things are structural choices | |
| (which dim connects where), control flow, and pure math (sin/cos/arange | |
| for PE generation — same convention as the original). | |
| Parameter budget: | |
| wte.A (10×1) + wte.B (1×4) = 14 [factorized embedding] | |
| Q (2 angles + 1 scale) = 3 [rotation-parameterized] | |
| K (SlimLinear 4×2) = 8 [reads only sin/cos dims] | |
| V (Rank1Linear u4 + v4) = 8 [rank-1, selects digit dim] | |
| c_proj (2 sparse scalars) = 2 [sparse projection] | |
| MLP c_fc (2×4 rows + 4 biases) = 12 [rank-2 shared rows] | |
| MLP c_proj (Rank1Linear u4 + v4) = 8 [rank-1] | |
| LM head (3 scalars) = 3 [parametric parabolic] | |
| ────────────────────────────────────────── | |
| TOTAL: 58 params | |
| Savings vs original (130 → 58 = -72 params): | |
| c_attn 48→19: Rotation Q (3) + SlimLinear K (8) + Rank1 V (8) | |
| c_proj 16→2: Sparse (2 stored values) | |
| c_fc 20→12: Shared-row rank-2 weight | |
| head 24→3: Parametric parabolic (u[v]=2v, bias=-v² from 2 coeffs) | |
| """ | |
| import math | |
| from dataclasses import dataclass | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| # ─── Building blocks ──────────────────────────────────────────────── | |
class FactorizedEmbedding(nn.Module):
    """Low-rank token embedding: look up a rank-`rank` code per token id,
    then expand it to `emb_dim` through a shared projection.

    Parameter count: vocab_size*rank + rank*emb_dim (14 here with rank=1).
    """
    def __init__(self, vocab_size, emb_dim, rank=1):
        super().__init__()
        self.A = nn.Parameter(torch.zeros(vocab_size, rank))  # per-token codes
        self.B = nn.Parameter(torch.zeros(rank, emb_dim))     # shared expansion

    def forward(self, x):
        # Gather each token's code, then project up: (..., rank) @ (rank, emb_dim).
        codes = self.A[x]
        return codes @ self.B
class Rank1Linear(nn.Module):
    """Linear map whose weight is the outer product u vᵀ (rank one).

    Parameter count: out_features + in_features. W is never materialized:
    the input is first projected onto v, then scaled out along u.
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        self.u = nn.Parameter(torch.zeros(out_features, 1))  # column factor
        self.v = nn.Parameter(torch.zeros(1, in_features))   # row factor

    def forward(self, x):
        projected = x.matmul(self.v.t())      # (..., 1): dot with v
        return projected.matmul(self.u.t())   # (..., out_features): scale by u
class SlimLinear(nn.Module):
    """Linear layer that reads only a fixed subset of input dims.

    Parameter count: out_features × len(input_dims); all other input
    dims are structurally ignored (implicit zero columns).
    """
    def __init__(self, out_features, input_dims):
        super().__init__()
        self.input_dims = input_dims
        self.weight = nn.Parameter(torch.zeros(out_features, len(input_dims)))

    def forward(self, x):
        selected = x[..., self.input_dims]  # drop every dim not listed
        return selected.matmul(self.weight.t())
class RotationQ(nn.Module):
    """Q projection parameterized by 2 rotation angles + 1 shared scale.

    3 params expand into a 4×4 matrix that reads only input columns 1 and 2
    (the sin/cos PE dims): rows 0-1 use angle_h0, rows 2-3 use angle_h1.
    """
    def __init__(self):
        super().__init__()
        self.angle_h0 = nn.Parameter(torch.zeros(1))
        self.angle_h1 = nn.Parameter(torch.zeros(1))
        self.scale = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        scale = self.scale
        W = x.new_zeros(4, 4)
        # Each angle fills a 2×2 rotation-like pattern in its head's two rows.
        for head, angle in enumerate((self.angle_h0, self.angle_h1)):
            row = 2 * head
            W[row, 1] = -torch.cos(angle) * scale
            W[row, 2] = torch.sin(angle) * scale
            W[row + 1, 1] = torch.sin(angle) * scale
            W[row + 1, 2] = torch.cos(angle) * scale
        return x @ W.T
class SparseProj(nn.Module):
    """Projection with exactly two learned entries at hard-wired positions.

    2 params; every other weight is structurally zero, so only two
    input→output routes exist: dim 0 → dim 3 and dim 2 → dim 1.
    """
    def __init__(self):
        super().__init__()
        self.val0 = nn.Parameter(torch.zeros(1))  # maps attn_out[:,0] → dim 3
        self.val1 = nn.Parameter(torch.zeros(1))  # maps attn_out[:,2] → dim 1
    def forward(self, x):
        result = torch.zeros_like(x)
        result[..., 1] = self.val1 * x[..., 2]
        result[..., 3] = self.val0 * x[..., 0]
        return result
| # ─── Attention (21 params) ────────────────────────────────────────── | |
class Attention(nn.Module):
    """Two-head causal self-attention, 21 params total:
    RotationQ (3) + SlimLinear K (8) + Rank1Linear V (8) + SparseProj (2).
    """
    def __init__(self, n_embd=4, n_head=2):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.q_proj = RotationQ()                  # 3 params
        self.k_proj = SlimLinear(n_embd, [1, 2])   # 8 params
        self.v_proj = Rank1Linear(n_embd, n_embd)  # 8 params
        self.c_proj = SparseProj()                 # 2 params

    def forward(self, x):
        B, T, C = x.size()

        def split_heads(t):
            # (B, T, C) → (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = attended.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(merged)
| # ─── MLP (20 params) ──────────────────────────────────────────────── | |
class SharedRowMLP(nn.Module):
    """MLP whose c_fc weight has only two distinct rows (each used twice,
    so rank ≤ 2) plus a bias, followed by a rank-1 c_proj. 20 params.
    """
    def __init__(self, n_embd=4, n_hidden=4):
        super().__init__()
        self.row0 = nn.Parameter(torch.zeros(n_embd))     # shared by hidden units 0,1
        self.row1 = nn.Parameter(torch.zeros(n_embd))     # shared by hidden units 2,3
        self.fc_bias = nn.Parameter(torch.zeros(n_hidden))
        self.c_proj = Rank1Linear(n_hidden, n_embd)       # 8 params

    def forward(self, x):
        # Expand the two stored rows into the full 4×4 c_fc weight.
        weight = torch.stack((self.row0, self.row0, self.row1, self.row1))
        hidden = F.relu(x.matmul(weight.t()) + self.fc_bias)
        return self.c_proj(hidden)
| # ─── LM Head (3 params) ───────────────────────────────────────────── | |
class ParabolicHead(nn.Module):
    """Digit decoder from 3 scalars: logit(d) = lin·d·(x[...,3]·dim_w) + quad·d².

    With lin=2, quad=-1 the logits form a downward parabola in d whose
    argmax is the integer nearest x[...,3]·dim_w.
    """
    def __init__(self, vocab_size=10):
        super().__init__()
        self.vocab_size = vocab_size
        self.lin = nn.Parameter(torch.zeros(1))    # linear coefficient
        self.quad = nn.Parameter(torch.zeros(1))   # quadratic coefficient
        self.dim_w = nn.Parameter(torch.zeros(1))  # read-out weight on dim 3

    def forward(self, x):
        scalar = x[..., 3:4] * self.dim_w
        digits = torch.arange(self.vocab_size, device=x.device, dtype=x.dtype)
        digits = digits.view(1, 1, -1)
        return self.lin * digits * scalar + self.quad * digits ** 2
| # ─── GPT ───────────────────────────────────────────────────────────── | |
@dataclass
class GPTConfig:
    """Hyper-parameters for the 58-param adder GPT."""
    block_size: int = 35   # nominal context length (not enforced in forward)
    vocab_size: int = 10   # digits 0-9; '+' and '=' are reused as token 0
    n_embd: int = 4        # embedding width
    n_head: int = 2        # attention heads (head_dim = n_embd // n_head = 2)
    mlp_hidden: int = 4    # MLP hidden width
class GPT(nn.Module):
    """Single-block, 58-parameter GPT specialized for 10-digit addition."""
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = FactorizedEmbedding(config.vocab_size, config.n_embd)  # 14 params
        self.attn = Attention(config.n_embd, config.n_head)               # 21 params
        self.mlp = SharedRowMLP(config.n_embd, config.mlp_hidden)         # 20 params
        self.lm_head = ParabolicHead(config.vocab_size)                   #  3 params

    def generate_pe(self, seq_len, device):
        """Sinusoidal PE computed from the formula — no stored values
        (same convention as the original).

        Dims 1/2 carry sin/cos at period 11; positions ≤ 21 (the prompt)
        get amplitude 100, later positions amplitude 1.
        """
        theta = 2 * math.pi / 11
        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        amplitude = torch.where(positions <= 21, 100.0, 1.0)
        pe = torch.zeros(seq_len, self.config.n_embd, device=device)
        pe[:, 1] = amplitude * torch.sin(positions * theta)
        pe[:, 2] = amplitude * torch.cos(positions * theta)
        return pe

    def forward(self, idx, targets=None):
        x = self.wte(idx) + self.generate_pe(idx.size(1), idx.device)
        x = x + self.attn(x)
        x = x + self.mlp(x)
        # Only the last position is decoded; no loss is ever computed here.
        logits = self.lm_head(x[:, [-1], :])
        return logits, None
| # ─── Weight injection ──────────────────────────────────────────────── | |
def build_adder():
    """Construct the GPT and hand-set all 58 parameters — no training."""
    config = GPTConfig()
    model = GPT(config)
    theta = 2 * math.pi / 11  # must match the PE period in GPT.generate_pe
    with torch.no_grad():
        # Embedding: digit v → [v, 0, 0, 0].
        model.wte.A[:, 0] = torch.arange(10, dtype=torch.float32)
        model.wte.B[0, :] = torch.tensor([1.0, 0.0, 0.0, 0.0])
        # Q: rotation angles that create resonance at offsets 8 and 9.
        model.attn.q_proj.angle_h0.fill_(8 * theta)
        model.attn.q_proj.angle_h1.fill_(9 * theta)
        model.attn.q_proj.scale.fill_(100.0)
        # K: identity on the sin/cos dims for both heads.
        model.attn.k_proj.weight.copy_(torch.tensor([
            [1.0, 0.0], [0.0, 1.0],  # head 0
            [1.0, 0.0], [0.0, 1.0],  # head 1
        ]))
        # V: rank-1 weight selecting dim 0 (the digit value) for both heads.
        model.attn.v_proj.u.copy_(torch.tensor([[1.0], [0.0], [1.0], [0.0]]))
        model.attn.v_proj.v.copy_(torch.tensor([[1.0, 0.0, 0.0, 0.0]]))
        # c_proj: route head outputs to the carry-computation dims.
        model.attn.c_proj.val0.fill_(2.0)
        model.attn.c_proj.val1.fill_(2.0)
        # MLP: carry detection and propagation.
        model.mlp.row0.copy_(torch.tensor([-100.0, 100.0, 0.0, 0.0]))
        model.mlp.row1.copy_(torch.tensor([-10.0, 10.0, 0.0, 1000.0]))
        model.mlp.fc_bias.copy_(torch.tensor([-50.0, -150.0, -9005.0, -9015.0]))
        model.mlp.c_proj.u.zero_()
        model.mlp.c_proj.v.zero_()
        model.mlp.c_proj.u[3, 0] = 1.0
        model.mlp.c_proj.v[0, :] = torch.tensor([0.01, -0.01, -1.0, 1.0])
        # LM head: parabolic decoding (argmax recovers the nearest integer).
        model.lm_head.lin.fill_(2.0)
        model.lm_head.quad.fill_(-1.0)
        model.lm_head.dim_w.fill_(1.0)
    return model
| # ─── Test harness ──────────────────────────────────────────────────── | |
| def test(model, a, b): | |
| tok = {'+': 0, '=': 0} | |
| seq = [int(c) if c.isdigit() else tok[c] for c in f"{a:010d}+{b:010d}="] | |
| model.eval() | |
| with torch.no_grad(): | |
| for _ in range(11): | |
| logits, _ = model(torch.tensor([seq])) | |
| seq.append(logits[0, -1].argmax().item()) | |
| result = int("".join(str(t) for t in seq[22:])[::-1]) | |
| ok = result == a + b | |
| print(f" {a:>13,d} + {b:>13,d} = {result:>13,d} {'✅' if ok else '❌'}") | |
| return ok | |
| if __name__ == "__main__": | |
| model = build_adder() | |
| params = sum(p.numel() for p in model.parameters()) | |
| bufs = sum(b.numel() for b in model.buffers()) | |
| print(f"\n Optimized NanoGPT Adder: {params} params, {bufs} buffers\n") | |
| for name, p in model.named_parameters(): | |
| print(f" {name:40s} {str(list(p.shape)):>10s} = {p.numel():>3d}") | |
| print(f" {'TOTAL':40s} {'':>10s} = {params:>3d}") | |
| assert bufs == 0, f"Expected 0 buffers, got {bufs}" | |
| print() | |
| tests = [ | |
| (5, 5), (555, 445), (99999, 1), (19492, 23919), | |
| (9999999999, 1), (1234567890, 987654321), (0, 0), | |
| (1111111111, 8888888889), | |
| (9999999999, 9999999999), (1000000000, 0), (4567891234, 5432108766), | |
| ] | |
| passed = sum(test(model, a, b) for a, b in tests) | |
| print(f"\n {passed}/{len(tests)} passed") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment