microgpt.rs
/// The most atomic way to train and inference a GPT in pure, dependency-free Rust.
/// This file is the complete algorithm. Everything else is just efficiency.
///
/// Ported from @karpathy's microgpt.py
///
/// Compile: rustc -O microgpt.rs
/// Run: ./microgpt (expects input.txt in current directory)
use std::cell::RefCell;
use std::collections::HashSet;
use std::fs;
use std::rc::Rc;
// ---- RNG (xorshift64 + Box-Muller + Fisher-Yates) ----
// Can be replaced with `rand` + `rand_distr` crates via Cargo.
struct Rng { state: u64 }
impl Rng {
    fn new(seed: u64) -> Self { Self { state: if seed == 0 { 1 } else { seed } } }
    fn next_u64(&mut self) -> u64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        self.state
    }
    fn next_f64(&mut self) -> f64 { (self.next_u64() >> 11) as f64 / ((1u64 << 53) as f64) }
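    // Box-Muller transform: two independent uniforms u1, u2 become one Gaussian sample
    // via mean + std * sqrt(-2 ln u1) * cos(2*pi*u2); u1 is clamped away from zero so
    // the logarithm stays finite.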
    fn gauss(&mut self, mean: f64, std: f64) -> f64 {
        let u1 = self.next_f64().max(1e-15);
        let u2 = self.next_f64();
        mean + std * (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }
    fn shuffle<T>(&mut self, v: &mut [T]) {
        for i in (1..v.len()).rev() { v.swap(i, (self.next_u64() as usize) % (i + 1)); }
    }
    fn weighted_choice(&mut self, weights: &[f64]) -> usize {
        let mut r = self.next_f64() * weights.iter().sum::<f64>();
        for (i, &w) in weights.iter().enumerate() { r -= w; if r <= 0.0 { return i; } }
        weights.len() - 1
    }
}
// ---- Autograd: Value node with computation graph for automatic differentiation ----
struct ValueInner { data: f64, grad: f64, children: Vec<Val>, local_grads: Vec<f64> }
#[derive(Clone)]
struct Val(Rc<RefCell<ValueInner>>);
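// `Val` is a cheap-to-clone handle (Rc<RefCell<..>>) onto a shared graph node, so the same
// value can feed into many downstream expressions and accumulate gradient from each use
// during the backward pass.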
impl Val {
    fn new(data: f64) -> Self {
        Self(Rc::new(RefCell::new(ValueInner { data, grad: 0.0, children: vec![], local_grads: vec![] })))
    }
    fn from(data: f64, children: Vec<Val>, local_grads: Vec<f64>) -> Self {
        Self(Rc::new(RefCell::new(ValueInner { data, grad: 0.0, children, local_grads })))
    }
    fn d(&self) -> f64 { self.0.borrow().data }
    fn g(&self) -> f64 { self.0.borrow().grad }
    fn set_d(&self, v: f64) { self.0.borrow_mut().data = v; }
    fn set_g(&self, v: f64) { self.0.borrow_mut().grad = v; }
    fn add(&self, other: &Val) -> Val {
        Val::from(self.d() + other.d(), vec![self.clone(), other.clone()], vec![1.0, 1.0])
    }
    fn mul(&self, other: &Val) -> Val {
        Val::from(self.d() * other.d(), vec![self.clone(), other.clone()], vec![other.d(), self.d()])
    }
    fn powf(&self, exp: f64) -> Val {
        Val::from(self.d().powf(exp), vec![self.clone()], vec![exp * self.d().powf(exp - 1.0)])
    }
    fn log(&self) -> Val { Val::from(self.d().ln(), vec![self.clone()], vec![1.0 / self.d()]) }
    fn exp(&self) -> Val { let e = self.d().exp(); Val::from(e, vec![self.clone()], vec![e]) }
    fn relu(&self) -> Val {
        let d = self.d();
        Val::from(d.max(0.0), vec![self.clone()], vec![if d > 0.0 { 1.0 } else { 0.0 }])
    }
    fn neg(&self) -> Val { self.muls(-1.0) }
    fn div(&self, other: &Val) -> Val { self.mul(&other.powf(-1.0)) }
    fn muls(&self, s: f64) -> Val { self.mul(&Val::new(s)) }
    fn adds(&self, s: f64) -> Val { self.add(&Val::new(s)) }
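    // Reverse-mode autodiff: an iterative DFS builds a topological order of the graph, then
    // gradients flow from this node (seeded with grad = 1.0) back through each node's children
    // (its inputs), multiplying by the precomputed local derivatives (the chain rule).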
    fn backward(&self) {
        let mut topo: Vec<Val> = Vec::new();
        let mut visited: HashSet<usize> = HashSet::new();
        let mut stack: Vec<(Val, bool)> = vec![(self.clone(), false)];
        while let Some((v, processed)) = stack.pop() {
            if processed { topo.push(v); continue; }
            let ptr = Rc::as_ptr(&v.0) as usize;
            if !visited.insert(ptr) { continue; }
            stack.push((v.clone(), true));
            for child in v.0.borrow().children.iter().rev() {
                if !visited.contains(&(Rc::as_ptr(&child.0) as usize)) {
                    stack.push((child.clone(), false));
                }
            }
        }
        self.0.borrow_mut().grad = 1.0;
        for v in topo.iter().rev() {
            let (children, local_grads, vg) = {
                let inner = v.0.borrow();
                (inner.children.clone(), inner.local_grads.clone(), inner.grad)
            };
            for (child, lg) in children.iter().zip(local_grads.iter()) {
                child.0.borrow_mut().grad += lg * vg;
            }
        }
    }
}
// ---- Neural network building blocks ----
type Mat = Vec<Vec<Val>>;
fn init_matrix(rng: &mut Rng, rows: usize, cols: usize, std: f64) -> Mat {
    (0..rows).map(|_| (0..cols).map(|_| Val::new(rng.gauss(0.0, std))).collect()).collect()
}
fn linear(x: &[Val], w: &Mat) -> Vec<Val> {
    w.iter().map(|row| {
        let mut s = row[0].mul(&x[0]);
        for i in 1..row.len() { s = s.add(&row[i].mul(&x[i])); }
        s
    }).collect()
}
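// Softmax with the running max subtracted from every logit first: exp(x - max) cannot
// overflow, and the constant shift cancels out after normalization.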
fn softmax(logits: &[Val]) -> Vec<Val> {
    let max_v = logits.iter().map(|v| v.d()).fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<Val> = logits.iter().map(|v| v.adds(-max_v).exp()).collect();
    let total = exps[1..].iter().fold(exps[0].clone(), |a, e| a.add(e));
    exps.iter().map(|e| e.div(&total)).collect()
}
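// RMSNorm: scale x by 1 / sqrt(mean(x^2) + eps). Like LayerNorm, but with no mean-centering
// and no learned gain or bias parameters.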
fn rmsnorm(x: &[Val]) -> Vec<Val> {
    let n = x.len() as f64;
    let mut ms = x[0].mul(&x[0]);
    for i in 1..x.len() { ms = ms.add(&x[i].mul(&x[i])); }
    let scale = ms.muls(1.0 / n).adds(1e-5).powf(-0.5);
    x.iter().map(|xi| xi.mul(&scale)).collect()
}
// ---- Transformer layer weights ----
struct Layer { attn_wq: Mat, attn_wk: Mat, attn_wv: Mat, attn_wo: Mat, mlp_fc1: Mat, mlp_fc2: Mat }
// ---- GPT forward pass (single token, KV-cache style) ----
fn gpt(
    token_id: usize, pos_id: usize,
    keys: &mut [Vec<Vec<Val>>], values: &mut [Vec<Vec<Val>>],
    wte: &Mat, wpe: &Mat, lm_head: &Mat, layers: &[Layer],
    n_head: usize, head_dim: usize,
) -> Vec<Val> {
    let mut x: Vec<Val> = wte[token_id].iter().zip(&wpe[pos_id]).map(|(t, p)| t.add(p)).collect();
    x = rmsnorm(&x);
    for (li, layer) in layers.iter().enumerate() {
        // 1) Multi-head attention block
        let xr = x.clone();
        x = rmsnorm(&x);
        let q = linear(&x, &layer.attn_wq);
        let k = linear(&x, &layer.attn_wk);
        let v = linear(&x, &layer.attn_wv);
        keys[li].push(k);
        values[li].push(v);
        let mut x_attn = Vec::with_capacity(n_head * head_dim);
        let scale = (head_dim as f64).sqrt();
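        // Scaled dot-product attention per head: softmax(q . k / sqrt(head_dim)) over every
        // cached position, then a weighted sum of the cached values. Because the caches only
        // hold positions processed so far, attention is causal by construction.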
        for h in 0..n_head {
            let hs = h * head_dim;
            let q_h = &q[hs..hs + head_dim];
            let al: Vec<Val> = (0..keys[li].len()).map(|t| {
                let mut dot = q_h[0].mul(&keys[li][t][hs]);
                for j in 1..head_dim { dot = dot.add(&q_h[j].mul(&keys[li][t][hs + j])); }
                dot.muls(1.0 / scale)
            }).collect();
            let aw = softmax(&al);
            for j in 0..head_dim {
                let mut s = aw[0].mul(&values[li][0][hs + j]);
                for t in 1..aw.len() { s = s.add(&aw[t].mul(&values[li][t][hs + j])); }
                x_attn.push(s);
            }
        }
        x = linear(&x_attn, &layer.attn_wo);
        x = x.iter().zip(&xr).map(|(a, b)| a.add(b)).collect();
        // 2) MLP block
        let xr = x.clone();
        x = rmsnorm(&x);
        x = linear(&x, &layer.mlp_fc1);
        x = x.iter().map(|xi| xi.relu()).collect();
        x = linear(&x, &layer.mlp_fc2);
        x = x.iter().zip(&xr).map(|(a, b)| a.add(b)).collect();
    }
    linear(&x, lm_head)
}
// ---- Main: load data, train, and generate ----
fn main() {
    let mut rng = Rng::new(42); // Let there be order among chaos
    // Let there be an input dataset
    let content = fs::read_to_string("input.txt").expect("input.txt not found");
    let mut docs: Vec<String> = content.lines().map(|l| l.trim().to_string()).filter(|l| !l.is_empty()).collect();
    rng.shuffle(&mut docs);
    println!("num docs: {}", docs.len());
    // Let there be a Tokenizer
    let mut uchars: Vec<char> = docs.iter().flat_map(|d| d.chars()).collect();
    uchars.sort();
    uchars.dedup();
    let bos = uchars.len();
    let vocab_size = uchars.len() + 1;
    println!("vocab size: {}", vocab_size);
    // Hyperparameters
    let (n_embd, n_head, n_layer, block_size) = (16, 4, 1, 16);
    let head_dim = n_embd / n_head;
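    // Toy-scale settings (16-dim embeddings, 4 heads, 1 layer, context of 16); larger values
    // trade extra compute for model quality.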
    // Initialize the parameters
    let wte = init_matrix(&mut rng, vocab_size, n_embd, 0.08);
    let wpe = init_matrix(&mut rng, block_size, n_embd, 0.08);
    let lm_head = init_matrix(&mut rng, vocab_size, n_embd, 0.08);
    let layers: Vec<Layer> = (0..n_layer).map(|_| Layer {
        attn_wq: init_matrix(&mut rng, n_embd, n_embd, 0.08),
        attn_wk: init_matrix(&mut rng, n_embd, n_embd, 0.08),
        attn_wv: init_matrix(&mut rng, n_embd, n_embd, 0.08),
        attn_wo: init_matrix(&mut rng, n_embd, n_embd, 0.08),
        mlp_fc1: init_matrix(&mut rng, 4 * n_embd, n_embd, 0.08),
        mlp_fc2: init_matrix(&mut rng, n_embd, 4 * n_embd, 0.08),
    }).collect();
    // Flatten params into a single list
    let mut params: Vec<Val> = Vec::new();
    for mat in [&wte, &wpe, &lm_head] {
        for row in mat { for p in row { params.push(p.clone()); } }
    }
    for layer in &layers {
        for mat in [&layer.attn_wq, &layer.attn_wk, &layer.attn_wv,
                    &layer.attn_wo, &layer.mlp_fc1, &layer.mlp_fc2] {
            for row in mat { for p in row { params.push(p.clone()); } }
        }
    }
    println!("num params: {}", params.len());
    // Let there be Adam, the blessed optimizer and its buffers
    let (lr, beta1, beta2, eps) = (0.01_f64, 0.85_f64, 0.99_f64, 1e-8_f64);
    let mut m_buf = vec![0.0f64; params.len()];
    let mut v_buf = vec![0.0f64; params.len()];
    // Repeat in sequence
    let num_steps = 1000;
    for step in 0..num_steps {
        let doc = &docs[step % docs.len()];
        let mut tokens: Vec<usize> = vec![bos];
        for ch in doc.chars() { tokens.push(uchars.iter().position(|&c| c == ch).unwrap()); }
        tokens.push(bos);
        let n = block_size.min(tokens.len() - 1);
        let mut keys: Vec<Vec<Vec<Val>>> = (0..n_layer).map(|_| Vec::new()).collect();
        let mut vals: Vec<Vec<Vec<Val>>> = (0..n_layer).map(|_| Vec::new()).collect();
        let mut losses: Vec<Val> = Vec::new();
        for pos in 0..n {
            let logits = gpt(tokens[pos], pos, &mut keys, &mut vals, &wte, &wpe, &lm_head, &layers, n_head, head_dim);
            let probs = softmax(&logits);
            losses.push(probs[tokens[pos + 1]].log().neg());
        }
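        // Average the per-position negative log-likelihoods of the target tokens:
        // the standard next-token cross-entropy loss for this document.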
        let mut loss = losses[0].clone();
        for i in 1..losses.len() { loss = loss.add(&losses[i]); }
        loss = loss.muls(1.0 / n as f64);
        loss.backward();
        let lr_t = lr * (1.0 - step as f64 / num_steps as f64);
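        // Adam update: exponential moving averages of the gradient (m) and its square (v),
        // bias-corrected for early steps, with the learning rate decayed linearly to zero.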
        for (i, p) in params.iter().enumerate() {
            let g = p.g();
            m_buf[i] = beta1 * m_buf[i] + (1.0 - beta1) * g;
            v_buf[i] = beta2 * v_buf[i] + (1.0 - beta2) * g * g;
            let m_hat = m_buf[i] / (1.0 - beta1.powi((step + 1) as i32));
            let v_hat = v_buf[i] / (1.0 - beta2.powi((step + 1) as i32));
            p.set_d(p.d() - lr_t * m_hat / (v_hat.sqrt() + eps));
            p.set_g(0.0);
        }
        println!("step {:4} / {:4} | loss {:.4}", step + 1, num_steps, loss.d());
    }
    // Inference: may the model babble back to us
    let temperature = 0.5_f64;
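    // Temperature scaling: dividing logits by a value below 1.0 sharpens the distribution
    // toward likely characters; values near or above 1.0 give more diverse, noisier samples.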
| println!("\n--- inference (new, hallucinated names) ---"); | |
| for si in 0..20 { | |
| let mut keys: Vec<Vec<Vec<Val>>> = (0..n_layer).map(|_| Vec::new()).collect(); | |
| let mut vals: Vec<Vec<Vec<Val>>> = (0..n_layer).map(|_| Vec::new()).collect(); | |
| let mut tok = bos; | |
| let mut name = String::new(); | |
| for pos in 0..block_size { | |
| let logits = gpt(tok, pos, &mut keys, &mut vals, &wte, &wpe, &lm_head, &layers, n_head, head_dim); | |
| let scaled: Vec<Val> = logits.iter().map(|l| l.muls(1.0 / temperature)).collect(); | |
| let probs = softmax(&scaled); | |
| let weights: Vec<f64> = probs.iter().map(|p| p.d()).collect(); | |
| tok = rng.weighted_choice(&weights); | |
| if tok == bos { break; } | |
| name.push(uchars[tok]); | |
| } | |
| println!("sample {:2}: {}", si + 1, name); | |
| } | |
| } |