Skip to content

Instantly share code, notes, and snippets.

@chaliy
Forked from karpathy/microgpt.py
Created February 13, 2026 03:38
Show Gist options
  • Select an option

  • Save chaliy/677b9feb95ca7b655790fbfcbc17c5dd to your computer and use it in GitHub Desktop.

Select an option

Save chaliy/677b9feb95ca7b655790fbfcbc17c5dd to your computer and use it in GitHub Desktop.
microgpt
/// The most atomic way to train and inference a GPT in pure, dependency-free Rust.
/// This file is the complete algorithm. Everything else is just efficiency.
///
/// Ported from @karpathy's microgpt.py
///
/// Compile: rustc -O microgpt.rs
/// Run: ./microgpt (expects input.txt in current directory)
use std::cell::RefCell;
use std::collections::HashSet;
use std::fs;
use std::rc::Rc;
// ---- RNG (xorshift64 + Box-Muller + Fisher-Yates) ----
// Can be replaced with `rand` + `rand_distr` crates via Cargo.
/// Minimal deterministic PRNG: an xorshift64 core plus Gaussian sampling
/// (Box-Muller), in-place shuffling (Fisher-Yates) and weighted sampling.
struct Rng { state: u64 }

impl Rng {
    /// Seed the generator; a zero seed is remapped to 1, since xorshift
    /// would be stuck forever at state 0.
    fn new(seed: u64) -> Self {
        let state = if seed == 0 { 1 } else { seed };
        Self { state }
    }

    /// Next raw 64-bit value (xorshift64: three shift/xor steps, 13/7/17).
    fn next_u64(&mut self) -> u64 {
        let mut s = self.state;
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        self.state = s;
        s
    }

    /// Uniform f64 in [0, 1), built from the top 53 bits of the raw output.
    fn next_f64(&mut self) -> f64 {
        let bits = self.next_u64() >> 11;
        bits as f64 / (1u64 << 53) as f64
    }

    /// Normal sample via the Box-Muller transform (cosine branch only).
    fn gauss(&mut self, mean: f64, std: f64) -> f64 {
        let u1 = self.next_f64().max(1e-15); // clamp away from 0 to avoid ln(0)
        let u2 = self.next_f64();
        let magnitude = (-2.0 * u1.ln()).sqrt();
        mean + std * magnitude * (2.0 * std::f64::consts::PI * u2).cos()
    }

    /// In-place Fisher-Yates shuffle, walking from the back of the slice.
    fn shuffle<T>(&mut self, v: &mut [T]) {
        for i in (1..v.len()).rev() {
            let j = (self.next_u64() as usize) % (i + 1);
            v.swap(i, j);
        }
    }

    /// Sample an index with probability proportional to its weight; the last
    /// index is the fallback if floating-point residue never crosses zero.
    fn weighted_choice(&mut self, weights: &[f64]) -> usize {
        let total: f64 = weights.iter().sum();
        let mut remaining = self.next_f64() * total;
        for (i, &w) in weights.iter().enumerate() {
            remaining -= w;
            if remaining <= 0.0 {
                return i;
            }
        }
        weights.len() - 1
    }
}
// ---- Autograd: Value node with computation graph for automatic differentiation ----
/// Payload of one autograd node: the scalar value, the gradient accumulated
/// during backprop, the inputs this value was computed from, and the local
/// derivative d(out)/d(input) for each of those inputs.
struct ValueInner {
    data: f64,
    grad: f64,
    children: Vec<Val>,
    local_grads: Vec<f64>,
}

/// Reference-counted handle to a node in the computation graph. Cloning a
/// `Val` clones the handle, not the node, so subexpressions can be shared.
#[derive(Clone)]
struct Val(Rc<RefCell<ValueInner>>);

impl Val {
    /// Leaf node with no inputs (e.g. a parameter or a constant).
    fn new(data: f64) -> Self {
        Val::from(data, Vec::new(), Vec::new())
    }

    /// Interior node: `data` computed from `children`, with one local
    /// derivative per child.
    fn from(data: f64, children: Vec<Val>, local_grads: Vec<f64>) -> Self {
        let inner = ValueInner { data, grad: 0.0, children, local_grads };
        Self(Rc::new(RefCell::new(inner)))
    }

    fn d(&self) -> f64 { self.0.borrow().data } // current value
    fn g(&self) -> f64 { self.0.borrow().grad } // accumulated gradient
    fn set_d(&self, v: f64) { self.0.borrow_mut().data = v; }
    fn set_g(&self, v: f64) { self.0.borrow_mut().grad = v; }

    /// z = a + b; dz/da = dz/db = 1.
    fn add(&self, other: &Val) -> Val {
        Val::from(self.d() + other.d(), vec![self.clone(), other.clone()], vec![1.0, 1.0])
    }

    /// z = a * b; dz/da = b, dz/db = a.
    fn mul(&self, other: &Val) -> Val {
        let (a, b) = (self.d(), other.d());
        Val::from(a * b, vec![self.clone(), other.clone()], vec![b, a])
    }

    /// z = a^e for a constant exponent; dz/da = e * a^(e-1).
    fn powf(&self, exp: f64) -> Val {
        let base = self.d();
        Val::from(base.powf(exp), vec![self.clone()], vec![exp * base.powf(exp - 1.0)])
    }

    /// z = ln(a); dz/da = 1/a.
    fn log(&self) -> Val {
        let a = self.d();
        Val::from(a.ln(), vec![self.clone()], vec![1.0 / a])
    }

    /// z = e^a; dz/da = e^a (the output itself).
    fn exp(&self) -> Val {
        let e = self.d().exp();
        Val::from(e, vec![self.clone()], vec![e])
    }

    /// z = max(a, 0); derivative is 1 on the positive side, 0 elsewhere.
    fn relu(&self) -> Val {
        let a = self.d();
        let slope = if a > 0.0 { 1.0 } else { 0.0 };
        Val::from(a.max(0.0), vec![self.clone()], vec![slope])
    }

    /// Unary negation, expressed as multiplication by -1.
    fn neg(&self) -> Val { self.muls(-1.0) }

    /// Division, expressed as a * b^-1.
    fn div(&self, other: &Val) -> Val { self.mul(&other.powf(-1.0)) }

    /// Multiply by a scalar constant (wrapped as a leaf node).
    fn muls(&self, s: f64) -> Val { self.mul(&Val::new(s)) }

    /// Add a scalar constant (wrapped as a leaf node).
    fn adds(&self, s: f64) -> Val { self.add(&Val::new(s)) }

    /// Reverse-mode autodiff: topologically sort the graph reachable from
    /// this node (iterative post-order DFS), seed d(self)/d(self) = 1, then
    /// apply the chain rule to children in reverse topological order.
    fn backward(&self) {
        let mut order: Vec<Val> = Vec::new();
        let mut seen: HashSet<usize> = HashSet::new();
        // Each node is pushed twice: once to expand its children, and once
        // more (flagged true) to emit it after the children are done.
        let mut work: Vec<(Val, bool)> = vec![(self.clone(), false)];
        while let Some((node, expanded)) = work.pop() {
            if expanded {
                order.push(node);
                continue;
            }
            // Identity of a node is its allocation address; `insert` returning
            // false means we already visited it (shared subexpression).
            let key = Rc::as_ptr(&node.0) as usize;
            if !seen.insert(key) {
                continue;
            }
            work.push((node.clone(), true));
            for child in node.0.borrow().children.iter().rev() {
                let child_key = Rc::as_ptr(&child.0) as usize;
                if !seen.contains(&child_key) {
                    work.push((child.clone(), false));
                }
            }
        }
        self.0.borrow_mut().grad = 1.0;
        for node in order.iter().rev() {
            // Snapshot the node so no borrow is held while mutating children.
            let (children, local_grads, upstream) = {
                let inner = node.0.borrow();
                (inner.children.clone(), inner.local_grads.clone(), inner.grad)
            };
            for (child, lg) in children.iter().zip(local_grads.iter()) {
                child.0.borrow_mut().grad += lg * upstream;
            }
        }
    }
}
// ---- Neural network building blocks ----
/// A dense matrix of autograd scalars, stored row-major.
type Mat = Vec<Vec<Val>>;

/// Allocate a `rows` x `cols` matrix of i.i.d. Gaussian entries with mean 0
/// and the given standard deviation, drawn from `rng` in row-major order.
fn init_matrix(rng: &mut Rng, rows: usize, cols: usize, std: f64) -> Mat {
    let mut matrix = Vec::with_capacity(rows);
    for _ in 0..rows {
        let mut row = Vec::with_capacity(cols);
        for _ in 0..cols {
            row.push(Val::new(rng.gauss(0.0, std)));
        }
        matrix.push(row);
    }
    matrix
}
/// Matrix-vector product: out[r] = sum_i w[r][i] * x[i] (no bias term).
/// Rows must be non-empty; terms are accumulated left to right.
fn linear(x: &[Val], w: &Mat) -> Vec<Val> {
    w.iter()
        .map(|row| {
            row.iter()
                .zip(x)
                .map(|(wi, xi)| wi.mul(xi))
                .reduce(|acc, term| acc.add(&term))
                .unwrap()
        })
        .collect()
}
/// Numerically stable softmax over autograd values: subtract the (constant)
/// max logit before exponentiating, then normalize so the outputs sum to 1.
fn softmax(logits: &[Val]) -> Vec<Val> {
    let max_v = logits
        .iter()
        .map(|v| v.d())
        .fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<Val> = logits.iter().map(|v| v.adds(-max_v).exp()).collect();
    let mut total = exps[0].clone();
    for e in &exps[1..] {
        total = total.add(e);
    }
    exps.iter().map(|e| e.div(&total)).collect()
}
/// RMS normalization without a learned gain: scales each element of `x` by
/// 1 / sqrt(mean(x_i^2) + 1e-5), the epsilon guarding against division by 0.
fn rmsnorm(x: &[Val]) -> Vec<Val> {
    let mut sum_sq = x[0].mul(&x[0]);
    for xi in &x[1..] {
        sum_sq = sum_sq.add(&xi.mul(xi));
    }
    let inv_rms = sum_sq.muls(1.0 / x.len() as f64).adds(1e-5).powf(-0.5);
    x.iter().map(|xi| xi.mul(&inv_rms)).collect()
}
// ---- Transformer layer weights ----
/// Weights of one transformer block: the four attention projections and the
/// two MLP matrices (expand to 4*n_embd, ReLU in between, project back).
struct Layer {
    attn_wq: Mat, // query projection
    attn_wk: Mat, // key projection
    attn_wv: Mat, // value projection
    attn_wo: Mat, // attention output projection
    mlp_fc1: Mat, // MLP expansion (4*n_embd x n_embd)
    mlp_fc2: Mat, // MLP contraction (n_embd x 4*n_embd)
}
// ---- GPT forward pass (single token, KV-cache style) ----
/// Runs one token through the transformer and returns the vocabulary logits.
///
/// `keys`/`values` are per-layer KV caches holding the key/value projections
/// of every previously processed position; this call appends the current
/// position's K/V, so attention spans the whole prefix (causal by
/// construction — future positions simply aren't in the cache yet).
/// `wte`/`wpe` are the token/position embedding tables, `lm_head` the final
/// output projection. Head `h` owns the contiguous feature slice
/// [h*head_dim, (h+1)*head_dim).
fn gpt(
token_id: usize, pos_id: usize,
keys: &mut [Vec<Vec<Val>>], values: &mut [Vec<Vec<Val>>],
wte: &Mat, wpe: &Mat, lm_head: &Mat, layers: &[Layer],
n_head: usize, head_dim: usize,
) -> Vec<Val> {
// Embed: token embedding + position embedding, then normalize.
let mut x: Vec<Val> = wte[token_id].iter().zip(&wpe[pos_id]).map(|(t, p)| t.add(p)).collect();
x = rmsnorm(&x);
for (li, layer) in layers.iter().enumerate() {
// 1) Multi-head attention block (pre-norm, residual connection)
let xr = x.clone(); // saved for the residual add below
x = rmsnorm(&x);
let q = linear(&x, &layer.attn_wq);
let k = linear(&x, &layer.attn_wk);
let v = linear(&x, &layer.attn_wv);
// Extend this layer's KV cache with the current position.
keys[li].push(k);
values[li].push(v);
let mut x_attn = Vec::with_capacity(n_head * head_dim);
let scale = (head_dim as f64).sqrt(); // scaled dot-product denominator
for h in 0..n_head {
let hs = h * head_dim; // start of this head's feature slice
let q_h = &q[hs..hs + head_dim];
// Attention logits: dot(q_h, k_t) / sqrt(head_dim) for every cached position t.
let al: Vec<Val> = (0..keys[li].len()).map(|t| {
let mut dot = q_h[0].mul(&keys[li][t][hs]);
for j in 1..head_dim { dot = dot.add(&q_h[j].mul(&keys[li][t][hs + j])); }
dot.muls(1.0 / scale)
}).collect();
let aw = softmax(&al);
// This head's output: attention-weighted sum of cached values, per dimension.
for j in 0..head_dim {
let mut s = aw[0].mul(&values[li][0][hs + j]);
for t in 1..aw.len() { s = s.add(&aw[t].mul(&values[li][t][hs + j])); }
x_attn.push(s);
}
}
x = linear(&x_attn, &layer.attn_wo);
x = x.iter().zip(&xr).map(|(a, b)| a.add(b)).collect(); // residual add
// 2) MLP block (pre-norm, 4x expansion with ReLU, residual connection)
let xr = x.clone();
x = rmsnorm(&x);
x = linear(&x, &layer.mlp_fc1);
x = x.iter().map(|xi| xi.relu()).collect();
x = linear(&x, &layer.mlp_fc2);
x = x.iter().zip(&xr).map(|(a, b)| a.add(b)).collect(); // residual add
}
// Final projection of the hidden state to logits over the vocabulary.
linear(&x, lm_head)
}
// ---- Main: load data, train, and generate ----
/// Entry point: loads `input.txt` (one document per line), trains a tiny
/// character-level GPT with Adam, then samples 20 new sequences from it.
fn main() {
let mut rng = Rng::new(42); // fixed seed: the entire run is reproducible
// Load the dataset: one document per non-empty (trimmed) line.
let content = fs::read_to_string("input.txt").expect("input.txt not found");
let mut docs: Vec<String> = content.lines().map(|l| l.trim().to_string()).filter(|l| !l.is_empty()).collect();
rng.shuffle(&mut docs);
println!("num docs: {}", docs.len());
// Character-level tokenizer: ids 0..uchars.len() are the sorted unique
// characters; one extra id (`bos`) marks both start and end of a sequence.
let mut uchars: Vec<char> = docs.iter().flat_map(|d| d.chars()).collect();
uchars.sort();
uchars.dedup();
let bos = uchars.len();
let vocab_size = uchars.len() + 1;
println!("vocab size: {}", vocab_size);
// Hyperparameters: embedding width, attention heads, layers, context length.
let (n_embd, n_head, n_layer, block_size) = (16, 4, 1, 16);
let head_dim = n_embd / n_head;
// Initialize the parameters with small Gaussian noise.
let wte = init_matrix(&mut rng, vocab_size, n_embd, 0.08);
let wpe = init_matrix(&mut rng, block_size, n_embd, 0.08);
let lm_head = init_matrix(&mut rng, vocab_size, n_embd, 0.08);
let layers: Vec<Layer> = (0..n_layer).map(|_| Layer {
attn_wq: init_matrix(&mut rng, n_embd, n_embd, 0.08),
attn_wk: init_matrix(&mut rng, n_embd, n_embd, 0.08),
attn_wv: init_matrix(&mut rng, n_embd, n_embd, 0.08),
attn_wo: init_matrix(&mut rng, n_embd, n_embd, 0.08),
mlp_fc1: init_matrix(&mut rng, 4 * n_embd, n_embd, 0.08),
mlp_fc2: init_matrix(&mut rng, n_embd, 4 * n_embd, 0.08),
}).collect();
// Flatten every parameter into one list so the optimizer can iterate them
// (cloning a Val clones the handle, so these alias the weights above).
let mut params: Vec<Val> = Vec::new();
for mat in [&wte, &wpe, &lm_head] {
for row in mat { for p in row { params.push(p.clone()); } }
}
for layer in &layers {
for mat in [&layer.attn_wq, &layer.attn_wk, &layer.attn_wv,
&layer.attn_wo, &layer.mlp_fc1, &layer.mlp_fc2] {
for row in mat { for p in row { params.push(p.clone()); } }
}
}
println!("num params: {}", params.len());
// Adam optimizer state: first (m) and second (v) moment buffers per parameter.
let (lr, beta1, beta2, eps) = (0.01_f64, 0.85_f64, 0.99_f64, 1e-8_f64);
let mut m_buf = vec![0.0f64; params.len()];
let mut v_buf = vec![0.0f64; params.len()];
// Training loop: one document per step, cycling through the shuffled corpus.
let num_steps = 1000;
for step in 0..num_steps {
let doc = &docs[step % docs.len()];
// Tokenize as BOS + chars + BOS; each char id is its index in `uchars`.
let mut tokens: Vec<usize> = vec![bos];
for ch in doc.chars() { tokens.push(uchars.iter().position(|&c| c == ch).unwrap()); }
tokens.push(bos);
let n = block_size.min(tokens.len() - 1); // number of next-token predictions
// Fresh per-layer KV caches for this sequence.
let mut keys: Vec<Vec<Vec<Val>>> = (0..n_layer).map(|_| Vec::new()).collect();
let mut vals: Vec<Vec<Vec<Val>>> = (0..n_layer).map(|_| Vec::new()).collect();
let mut losses: Vec<Val> = Vec::new();
for pos in 0..n {
let logits = gpt(tokens[pos], pos, &mut keys, &mut vals, &wte, &wpe, &lm_head, &layers, n_head, head_dim);
let probs = softmax(&logits);
// Cross-entropy: negative log-probability of the true next token.
losses.push(probs[tokens[pos + 1]].log().neg());
}
// Mean loss over the sequence, then backprop through the whole graph.
let mut loss = losses[0].clone();
for i in 1..losses.len() { loss = loss.add(&losses[i]); }
loss = loss.muls(1.0 / n as f64);
loss.backward();
// Adam update with linear learning-rate decay and bias-corrected moments.
let lr_t = lr * (1.0 - step as f64 / num_steps as f64);
for (i, p) in params.iter().enumerate() {
let g = p.g();
m_buf[i] = beta1 * m_buf[i] + (1.0 - beta1) * g;
v_buf[i] = beta2 * v_buf[i] + (1.0 - beta2) * g * g;
let m_hat = m_buf[i] / (1.0 - beta1.powi((step + 1) as i32));
let v_hat = v_buf[i] / (1.0 - beta2.powi((step + 1) as i32));
p.set_d(p.d() - lr_t * m_hat / (v_hat.sqrt() + eps));
p.set_g(0.0); // zero the gradient for the next step
}
println!("step {:4} / {:4} | loss {:.4}", step + 1, num_steps, loss.d());
}
// Inference: autoregressively sample new sequences token by token.
let temperature = 0.5_f64; // < 1.0 sharpens the distribution toward the argmax
println!("\n--- inference (new, hallucinated names) ---");
for si in 0..20 {
let mut keys: Vec<Vec<Vec<Val>>> = (0..n_layer).map(|_| Vec::new()).collect();
let mut vals: Vec<Vec<Vec<Val>>> = (0..n_layer).map(|_| Vec::new()).collect();
let mut tok = bos; // start generation from the BOS marker
let mut name = String::new();
for pos in 0..block_size {
let logits = gpt(tok, pos, &mut keys, &mut vals, &wte, &wpe, &lm_head, &layers, n_head, head_dim);
// Temperature-scaled sampling from the softmax distribution.
let scaled: Vec<Val> = logits.iter().map(|l| l.muls(1.0 / temperature)).collect();
let probs = softmax(&scaled);
let weights: Vec<f64> = probs.iter().map(|p| p.d()).collect();
tok = rng.weighted_choice(&weights);
if tok == bos { break; } // model emitted the end marker
name.push(uchars[tok]);
}
println!("sample {:2}: {}", si + 1, name);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment