Skip to content

Instantly share code, notes, and snippets.

@cosminscn
Created February 25, 2026 02:36
Show Gist options
  • Select an option

  • Save cosminscn/e84e98a6d247299acea065167ece2753 to your computer and use it in GitHub Desktop.

Select an option

Save cosminscn/e84e98a6d247299acea065167ece2753 to your computer and use it in GitHub Desktop.
"""58-param nanoGPT that adds any two 10-digit numbers. No training.
Down from 130 params. HONEST counting: every stored numerical value is an
nn.Parameter. Zero buffers. The only "free" things are structural choices
(which dim connects where), control flow, and pure math (sin/cos/arange
for PE generation — same convention as the original).
Parameter budget:
wte.A (10×1) + wte.B (1×4) = 14 [factorized embedding]
Q (2 angles + 1 scale) = 3 [rotation-parameterized]
K (SlimLinear 4×2) = 8 [reads only sin/cos dims]
V (Rank1Linear u4 + v4) = 8 [rank-1, selects digit dim]
c_proj (2 sparse scalars) = 2 [sparse projection]
MLP c_fc (2×4 rows + 4 biases) = 12 [rank-2 shared rows]
MLP c_proj (Rank1Linear u4 + v4) = 8 [rank-1]
LM head (3 scalars) = 3 [parametric parabolic]
──────────────────────────────────────────
TOTAL: 58 params
Savings vs original (130 → 58 = -72 params):
c_attn 48→19: Rotation Q (3) + SlimLinear K (8) + Rank1 V (8)
c_proj 16→2: Sparse (2 stored values)
c_fc 20→12: Shared-row rank-2 weight
head 24→3: Parametric parabolic (u[v]=2v, bias=-v² from 2 coeffs)
"""
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
# ─── Building blocks ────────────────────────────────────────────────
class FactorizedEmbedding(nn.Module):
    """Low-rank token embedding: embed(x) = A[x] @ B.

    Stores two thin factors instead of a full (vocab_size x emb_dim)
    table, costing vocab_size*rank + rank*emb_dim parameters.
    """

    def __init__(self, vocab_size, emb_dim, rank=1):
        super().__init__()
        self.A = nn.Parameter(torch.zeros(vocab_size, rank))  # per-token codes
        self.B = nn.Parameter(torch.zeros(rank, emb_dim))     # shared expansion

    def forward(self, x):
        # Gather each token's rank-`rank` code, then expand to emb_dim.
        codes = self.A[x]
        return codes @ self.B
class Rank1Linear(nn.Module):
    """Linear layer whose weight is the outer product u ⊗ v (rank 1).

    Parameter cost: out_features + in_features instead of their product.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.u = nn.Parameter(torch.zeros(out_features, 1))  # output direction
        self.v = nn.Parameter(torch.zeros(1, in_features))   # input direction

    def forward(self, x):
        # Same association as (x @ v.T) @ u.T: project onto v first,
        # then scale u by that scalar — never materializes the full W.
        coefficient = x @ self.v.T      # shape (..., 1)
        return coefficient @ self.u.T   # shape (..., out_features)
class SlimLinear(nn.Module):
    """Linear layer that reads only a chosen subset of input dimensions.

    Parameter cost: out_features * len(input_dims).
    """

    def __init__(self, out_features, input_dims):
        super().__init__()
        self.input_dims = input_dims  # which input dims the weight touches
        self.weight = nn.Parameter(torch.zeros(out_features, len(input_dims)))

    def forward(self, x):
        selected = x[..., self.input_dims]
        return selected @ self.weight.T
class RotationQ(nn.Module):
    """Q projection parameterized by 2 rotation angles and 1 scale (3 params).

    Expands on the fly into a 4x4 matrix whose middle columns apply a
    scaled 2D rotation per head; columns 0 and 3 are always zero.
    """

    def __init__(self):
        super().__init__()
        self.angle_h0 = nn.Parameter(torch.zeros(1))  # head-0 rotation angle
        self.angle_h1 = nn.Parameter(torch.zeros(1))  # head-1 rotation angle
        self.scale = nn.Parameter(torch.zeros(1))     # shared magnitude

    def forward(self, x):
        s = self.scale
        W = x.new_zeros(4, 4)
        # Rows 2h and 2h+1 hold the rotation block for head h, reading
        # only input dims 1 and 2 (the sin/cos positional channels).
        for head, angle in enumerate((self.angle_h0, self.angle_h1)):
            r = 2 * head
            W[r, 1], W[r, 2] = -torch.cos(angle) * s, torch.sin(angle) * s
            W[r + 1, 1], W[r + 1, 2] = torch.sin(angle) * s, torch.cos(angle) * s
        return x @ W.T
class SparseProj(nn.Module):
    """Projection with exactly two non-zero entries at fixed positions (2 params)."""

    def __init__(self):
        super().__init__()
        self.val0 = nn.Parameter(torch.zeros(1))  # weight for x[..., 0] -> out[..., 3]
        self.val1 = nn.Parameter(torch.zeros(1))  # weight for x[..., 2] -> out[..., 1]

    def forward(self, x):
        # Everything except the two wired positions stays zero.
        result = torch.zeros_like(x)
        result[..., 3] = x[..., 0] * self.val0
        result[..., 1] = x[..., 2] * self.val1
        return result
# ─── Attention (21 params) ──────────────────────────────────────────
class Attention(nn.Module):
    """Causal self-attention with heavily factorized projections (21 params total)."""

    def __init__(self, n_embd=4, n_head=2):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.q_proj = RotationQ()                  # 3 params
        self.k_proj = SlimLinear(n_embd, [1, 2])   # 8 params, reads sin/cos dims
        self.v_proj = Rank1Linear(n_embd, n_embd)  # 8 params
        self.c_proj = SparseProj()                 # 2 params

    def _heads(self, t, batch, seq):
        # (B, T, C) -> (B, n_head, T, head_dim)
        return t.view(batch, seq, self.n_head, self.head_dim).transpose(1, 2)

    def forward(self, x):
        B, T, C = x.size()
        q = self._heads(self.q_proj(x), B, T)
        k = self._heads(self.k_proj(x), B, T)
        v = self._heads(self.v_proj(x), B, T)
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = attended.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(merged)
# ─── MLP (20 params) ────────────────────────────────────────────────
class SharedRowMLP(nn.Module):
    """MLP whose c_fc weight has only two distinct rows (rank 2) + rank-1 c_proj.

    Hidden units 0,1 share row0 and units 2,3 share row1, so the 4x4
    c_fc weight costs only 8 parameters (+4 bias).
    """

    def __init__(self, n_embd=4, n_hidden=4):
        super().__init__()
        self.row0 = nn.Parameter(torch.zeros(n_embd))       # rows 0 and 1 of c_fc
        self.row1 = nn.Parameter(torch.zeros(n_embd))       # rows 2 and 3 of c_fc
        self.fc_bias = nn.Parameter(torch.zeros(n_hidden))  # per-unit bias
        self.c_proj = Rank1Linear(n_hidden, n_embd)         # 8 params

    def forward(self, x):
        # Materialize the full weight: each stored row repeated twice.
        weight = torch.stack([self.row0, self.row1]).repeat_interleave(2, dim=0)
        hidden = F.relu(x @ weight.T + self.fc_bias)
        return self.c_proj(hidden)
# ─── LM Head (3 params) ─────────────────────────────────────────────
class ParabolicHead(nn.Module):
    """Digit decoder from 3 scalars: logit(d) = lin·d·(x[...,3]·dim_w) + quad·d².

    With lin > 0 and quad < 0 the logits trace a downward parabola in d,
    so argmax recovers the integer nearest the encoded scalar.
    """

    def __init__(self, vocab_size=10):
        super().__init__()
        self.vocab_size = vocab_size
        self.lin = nn.Parameter(torch.zeros(1))    # linear coefficient
        self.quad = nn.Parameter(torch.zeros(1))   # quadratic coefficient
        self.dim_w = nn.Parameter(torch.zeros(1))  # read-out weight for dim 3

    def forward(self, x):
        # Keep a trailing singleton so the scalar broadcasts over the digits.
        value = x[..., 3:4] * self.dim_w
        digits = torch.arange(self.vocab_size, device=x.device, dtype=x.dtype)
        digits = digits.view(1, 1, -1)
        return self.lin * digits * value + self.quad * digits ** 2
# ─── GPT ─────────────────────────────────────────────────────────────
@dataclass
class GPTConfig:
    """Hyperparameters for the hand-weighted adder model."""
    block_size: int = 35   # max sequence length; the 22-token prompt + 11 answer digits fits
    vocab_size: int = 10   # digits 0-9 ('+' and '=' reuse token 0 — see test())
    n_embd: int = 4        # embedding width
    n_head: int = 2        # attention heads (head_dim = n_embd // n_head = 2)
    mlp_hidden: int = 4    # MLP hidden units
class GPT(nn.Module):
    """Single-block transformer: factorized embedding + PE, one attention
    layer, one MLP (both residual), and a parabolic decoding head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = FactorizedEmbedding(config.vocab_size, config.n_embd)  # 14 params
        self.attn = Attention(config.n_embd, config.n_head)               # 21 params
        self.mlp = SharedRowMLP(config.n_embd, config.mlp_hidden)         # 20 params
        self.lm_head = ParabolicHead(config.vocab_size)                   # 3 params

    def generate_pe(self, seq_len, device):
        """Sinusoidal PE computed from a formula — stores no values."""
        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        step = 2 * math.pi / 11
        # Positions <= 21 (the question part of the sequence) get a 100x
        # amplitude; later (answer) positions get amplitude 1.
        amplitude = torch.where(positions <= 21, 100.0, 1.0)
        pe = torch.zeros(seq_len, self.config.n_embd, device=device)
        pe[:, 1] = amplitude * torch.sin(positions * step)
        pe[:, 2] = amplitude * torch.cos(positions * step)
        return pe

    def forward(self, idx, targets=None):
        h = self.wte(idx) + self.generate_pe(idx.size(1), idx.device)
        h = h + self.attn(h)
        h = h + self.mlp(h)
        # Only the last position is decoded; `targets` is accepted for
        # interface compatibility but no loss is ever computed (no training).
        return self.lm_head(h[:, [-1], :]), None
# ─── Weight injection ────────────────────────────────────────────────
def build_adder():
    """Hand-write all 58 parameters and return a ready-to-use model.

    No training happens anywhere; every value below is injected directly.
    """
    model = GPT(GPTConfig())
    theta = 2 * math.pi / 11
    with torch.no_grad():
        # Embedding: digit v -> [v, 0, 0, 0].
        model.wte.A.copy_(torch.arange(10, dtype=torch.float32).unsqueeze(1))
        model.wte.B.copy_(torch.tensor([[1.0, 0.0, 0.0, 0.0]]))
        # Q: rotation angles that create resonance at offsets 8 and 9.
        model.attn.q_proj.angle_h0.fill_(8 * theta)
        model.attn.q_proj.angle_h1.fill_(9 * theta)
        model.attn.q_proj.scale.fill_(100.0)
        # K: 2x2 identity over the sin/cos dims, duplicated for both heads.
        model.attn.k_proj.weight.copy_(torch.eye(2).repeat(2, 1))
        # V: rank-1 product that routes dim 0 (the digit value) to both heads.
        model.attn.v_proj.u.copy_(torch.tensor([1.0, 0.0, 1.0, 0.0]).unsqueeze(1))
        model.attn.v_proj.v.copy_(torch.tensor([[1.0, 0.0, 0.0, 0.0]]))
        # c_proj: scale head outputs into the carry-computation dims.
        model.attn.c_proj.val0.fill_(2.0)
        model.attn.c_proj.val1.fill_(2.0)
        # MLP: thresholded units for carry detection and propagation.
        model.mlp.row0.copy_(torch.tensor([-100.0, 100.0, 0.0, 0.0]))
        model.mlp.row1.copy_(torch.tensor([-10.0, 10.0, 0.0, 1000.0]))
        model.mlp.fc_bias.copy_(torch.tensor([-50.0, -150.0, -9005.0, -9015.0]))
        model.mlp.c_proj.u.copy_(torch.tensor([[0.0], [0.0], [0.0], [1.0]]))
        model.mlp.c_proj.v.copy_(torch.tensor([[0.01, -0.01, -1.0, 1.0]]))
        # LM head: logit(d) = 2·d·x[3] - d², so argmax decodes nearest digit.
        model.lm_head.lin.fill_(2.0)
        model.lm_head.quad.fill_(-1.0)
        model.lm_head.dim_w.fill_(1.0)
    return model
# ─── Test harness ────────────────────────────────────────────────────
def test(model, a, b):
    """Check that the model computes a + b; print and return the verdict.

    Encodes both operands zero-padded and *least-significant digit first is
    produced by the model*: the 11 generated tokens are reversed before
    comparison. '+' and '=' share token id 0 with the digit '0'.
    """
    specials = {'+': 0, '=': 0}
    prompt = f"{a:010d}+{b:010d}="
    seq = [specials[ch] if ch in specials else int(ch) for ch in prompt]
    model.eval()
    with torch.no_grad():
        # Autoregressively sample 11 answer digits (10 + possible carry).
        for _ in range(11):
            logits, _ = model(torch.tensor([seq]))
            seq.append(logits[0, -1].argmax().item())
    answer_tokens = seq[22:]  # everything after the 22-char prompt
    result = int("".join(str(t) for t in answer_tokens)[::-1])
    ok = result == a + b
    print(f" {a:>13,d} + {b:>13,d} = {result:>13,d} {'✅' if ok else '❌'}")
    return ok
if __name__ == "__main__":
    model = build_adder()
    params = sum(param.numel() for param in model.parameters())
    bufs = sum(buf.numel() for buf in model.buffers())
    print(f"\n Optimized NanoGPT Adder: {params} params, {bufs} buffers\n")
    # Per-parameter breakdown, then the grand total.
    for name, param in model.named_parameters():
        print(f" {name:40s} {str(list(param.shape)):>10s} = {param.numel():>3d}")
    print(f" {'TOTAL':40s} {'':>10s} = {params:>3d}")
    # The "honest counting" claim: no values hidden in registered buffers.
    assert bufs == 0, f"Expected 0 buffers, got {bufs}"
    print()
    cases = [
        (5, 5), (555, 445), (99999, 1), (19492, 23919),
        (9999999999, 1), (1234567890, 987654321), (0, 0),
        (1111111111, 8888888889),
        (9999999999, 9999999999), (1000000000, 0), (4567891234, 5432108766),
    ]
    passed = 0
    for a, b in cases:
        passed += test(model, a, b)
    print(f"\n {passed}/{len(cases)} passed")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment