Created
February 13, 2026 20:40
-
-
Save behrica/8971efff0169b1116a733d59eaa4a35b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(ns microgpt.core
  (:require [clojure.java.io]
            [clojure.string])
  (:gen-class))
| ;; A minimal, dependency-free translation of the provided Python `microgpt.py` into Clojure. | |
| ;; This keeps the same algorithmic structure: a tiny autograd Value type, a GPT-like | |
| ;; forward function, Adam updates, and a simple training loop. This is intended for | |
| ;; educational purposes and may be slow for real training. | |
| ;; Utilities and deterministic RNG | |
;; Deterministic RNG shared by initialization and sampling (fixed seed 42
;; so runs are reproducible, mirroring the Python original).
(def ^java.util.Random rng (java.util.Random. 42))

(defn rand-gauss
  "Sample from the standard normal distribution N(0, 1)."
  []
  (.nextGaussian rng))

(defn rand-double
  "Sample a uniform double in [0.0, 1.0)."
  []
  (.nextDouble rng))

(defn rand-int
  "Sample a uniform int in [0, n). NOTE: shadows clojure.core/rand-int."
  [n]
  (.nextInt rng n))
| ;; Value representation: mutable atoms for :data and :grad so we can update params | |
(defn make-value
  "Create an autograd node wrapping the number x.
   :data and :grad are atoms so parameters can be updated in place.
   :children and :local-grads record the input nodes and the
   d(out)/d(child) factors captured at forward time, which `backward`
   later uses for the chain rule."
  ([x] (make-value x [] []))
  ([x children local-grads]
   (let [node-id (java.util.UUID/randomUUID)] ; unique id for the topo-sort visited set
     {:id node-id
      :data (atom (double x))
      :grad (atom 0.0)
      :children (into [] children)
      :local-grads (into [] local-grads)})))
(defn ensure-value
  "Coerce x to an autograd node: a map is assumed to already be a node
   and passes through unchanged; a plain number is wrapped with
   make-value. Anything else throws ex-info."
  [x]
  (if (map? x)
    x
    (if (number? x)
      (make-value x)
      (throw (ex-info "Unsupported value for ensure-value" {:val x})))))
| ;; Basic ops that create new Value nodes and store local grads (numbers computed at forward time) | |
(defn val-data
  "Current numeric value stored in node v, as a primitive double."
  [v]
  (double (deref (:data v))))

(defn v-add
  "Node for a + b; local grads d/da = d/db = 1."
  [a b]
  (let [av (ensure-value a)
        bv (ensure-value b)]
    (make-value (+ (val-data av) (val-data bv)) [av bv] [1.0 1.0])))

(defn v-mul
  "Node for a * b; local grads d/da = b, d/db = a, captured at forward time."
  [a b]
  (let [av (ensure-value a)
        bv (ensure-value b)
        x (val-data av)
        y (val-data bv)]
    (make-value (* x y) [av bv] [y x])))

(defn v-neg
  "Node for -a; local grad d/da = -1."
  [a]
  (let [av (ensure-value a)]
    (make-value (- (val-data av)) [av] [-1.0])))

(defn v-sub
  "Node for a - b, built as a + (-b)."
  [a b]
  (v-add (ensure-value a) (v-neg (ensure-value b))))
(defn v-pow
  "Node for a^e where e is a plain number; local grad d/da = e * a^(e-1)."
  [a e]
  (let [av (ensure-value a)
        x (val-data av)]
    (make-value (Math/pow x e) [av] [(* e (Math/pow x (dec e)))])))

(defn v-log
  "Node for ln(a); local grad d/da = 1/a.
   NOTE(review): yields NaN for a <= 0, as in the Python original —
   callers must keep inputs positive."
  [a]
  (let [av (ensure-value a)
        x (val-data av)]
    (make-value (Math/log x) [av] [(/ 1.0 x)])))

(defn v-exp
  "Node for e^a; the local grad equals the output value itself."
  [a]
  (let [av (ensure-value a)
        ex (Math/exp (val-data av))]
    (make-value ex [av] [ex])))

(defn v-relu
  "Node for max(0, a); local grad is 1 when a > 0, else 0."
  [a]
  (let [av (ensure-value a)
        x (val-data av)]
    (make-value (max 0.0 x) [av] [(if (> x 0.0) 1.0 0.0)])))

(defn v-div
  "Node for a / b, built as a * b^-1."
  [a b]
  (v-mul (ensure-value a) (v-pow (ensure-value b) -1)))
| ;; Backpropagation: build topo ordering then propagate gradients | |
(defn backward
  "Reverse-mode backprop from `root`.

   Builds a topological ordering of the graph (depth-first over
   :children, deduplicated by :id), seeds root's grad with 1.0, then
   walks the ordering in reverse applying the chain rule:
   child.grad += local-grad * node.grad.

   Fix over the previous version: the topo order was accumulated by
   calling conj! on a transient purely for side effects and discarding
   the return value, which violates the transient contract (the result
   of conj! must always be used; in-place mutation is an implementation
   detail that may not hold). A plain vector held in an atom is used
   instead.

   NOTE(review): build-topo recurses on the JVM call stack, so very
   deep graphs could overflow — acceptable for this toy model."
  [root]
  (let [topo (atom [])
        visited (atom #{})]
    (letfn [(build-topo [v]
              (when-not (contains? @visited (:id v))
                (swap! visited conj (:id v))
                (doseq [c (:children v)] (build-topo c))
                (swap! topo conj v)))]
      (build-topo root)
      (reset! (:grad root) 1.0)
      ;; rseq is O(1) reverse view of the accumulated vector
      (doseq [v (rseq @topo)]
        (let [g @(:grad v)]
          (doseq [[child local] (map vector (:children v) (:local-grads v))]
            (swap! (:grad child) + (* local g))))))))
| ;; Helpers for creating matrices/vectors of params | |
(defn matrix
  "Build an nout x nin matrix (vector of row vectors) of parameter
   nodes, each initialized with Gaussian noise scaled by :std
   (default 0.08)."
  [nout nin & {:keys [std] :or {std 0.08}}]
  (let [init-param (fn [] (make-value (* (double std) (rand-gauss))))]
    (mapv (fn [_row]
            (mapv (fn [_col] (init-param)) (range nin)))
          (range nout))))
| ;; Linear: x is vector of Value, w is matrix (rows of weights as Value) | |
(defn linear
  "Autograd matrix-vector product: x is a vector of Value nodes, w a
   matrix of Value rows. Returns, for each row, sum_i(w_i * x_i) as a
   Value node."
  [x w]
  (mapv (fn [row]
          ;; left fold keeps the same graph shape as an explicit loop
          (reduce v-add (make-value 0.0) (map v-mul row x)))
        w))
| ;; softmax over vector of Value | |
(defn softmax
  "Numerically stable softmax over a vector of Value nodes: the max raw
   value is subtracted as a constant before exponentiating, so the
   exponentials cannot overflow."
  [logits]
  (let [shift (make-value (apply max (map val-data logits)))
        exps (mapv #(v-exp (v-sub % shift)) logits)
        denom (reduce v-add (make-value 0.0) exps)]
    (mapv #(v-div % denom) exps)))
| ;; rmsnorm: scale vector of Value by computed scalar | |
(defn rmsnorm [x]
  ;; Root-mean-square normalization: scales each element by
  ;; 1 / sqrt(mean(x_i^2) + 1e-5).
  ;; NOTE(review): the scale is computed from raw doubles and wrapped as
  ;; a constant node, so no gradient flows through the normalization
  ;; statistics (only through each xi directly via v-mul). Confirm this
  ;; detachment is intentional relative to the Python original.
  (let [len (count x)
        ;; sum of squares built as an autograd expression, but its graph
        ;; is effectively discarded by the val-data call below
        ms (->> x (mapv (fn [xi] (v-mul xi xi))) (reduce (fn [a b] (v-add a b)) (make-value 0.0)))
        ms-val (/ (val-data ms) len)
        ;; constant scale factor: (mean-square + eps)^(-1/2)
        scale (Math/pow (+ ms-val 1e-5) -0.5)]
    (mapv (fn [xi] (v-mul xi (make-value scale))) x)))
| ;; Model hyperparameters (same defaults as Python) | |
(def n-embd 16)                     ; embedding width
(def n-head 4)                      ; attention heads per layer
(def n-layer 1)                     ; number of transformer layers
(def block-size 16)                 ; max sequence length / positional-embedding rows
(def head-dim (quot n-embd n-head)) ; per-head width (16 / 4 = 4)
| ;; Load dataset or fetch remote if missing | |
(defn ensure-input-file
  "Make sure ./input.txt exists, downloading the names dataset from the
   karpathy/makemore repository when it is missing. Returns nil."
  []
  (let [target (java.io.File. "input.txt")]
    (when-not (.exists target)
      (let [src (java.net.URL. "https://raw.githubusercontent.com/karpathy/makemore/refs/heads/master/names.txt")]
        (with-open [in (.openStream src)
                    out (java.io.FileOutputStream. target)]
          (clojure.java.io/copy in out))))))
(defn read-docs
  "Return the dataset as a vector of trimmed, non-empty lines (one name
   per line), downloading input.txt first if it is missing."
  []
  (ensure-input-file)
  (into []
        (comp (map clojure.string/trim)
              (filter seq))
        (clojure.string/split-lines (slurp "input.txt"))))
(defn build-vocab
  "Derive the character vocabulary from docs (a seq of strings).
   Returns {:uchars sorted-unique-chars, :BOS index-one-past-the-chars,
   :vocab-size (count uchars) + 1}; the BOS token doubles as EOS."
  [docs]
  (let [uchars (->> docs (apply str) distinct sort vec)]
    {:uchars uchars
     :BOS (count uchars)
     :vocab-size (inc (count uchars))}))
| ;; Build initial state dict | |
(defn init-state
  "Create the model's parameter map: token/position embeddings, the LM
   head, and per-layer attention + MLP weight matrices, keyed exactly
   like the Python state dict (e.g. :layer0.attn_wq)."
  [vocab-size]
  (let [layer-entries
        (fn [i]
          [[(keyword (format "layer%d.attn_wq" i)) (matrix n-embd n-embd)]
           [(keyword (format "layer%d.attn_wk" i)) (matrix n-embd n-embd)]
           [(keyword (format "layer%d.attn_wv" i)) (matrix n-embd n-embd)]
           [(keyword (format "layer%d.attn_wo" i)) (matrix n-embd n-embd)]
           [(keyword (format "layer%d.mlp_fc1" i)) (matrix (* 4 n-embd) n-embd)]
           [(keyword (format "layer%d.mlp_fc2" i)) (matrix n-embd (* 4 n-embd))]])]
    ;; evaluation order (wte, wpe, lm_head, then layers) matches the
    ;; original so the RNG draw sequence is unchanged
    (into {:wte (matrix vocab-size n-embd)
           :wpe (matrix block-size n-embd)
           :lm_head (matrix vocab-size n-embd)}
          (mapcat layer-entries (range n-layer)))))
(defn flatten-params
  "Collect every weight row from the state dict into one vector; each
   element is the sequence of Value nodes in a single row. Callers
   flatten one more level to reach individual parameters."
  [state-dict]
  (vec (for [[_ mat] state-dict
             row mat]
         (seq row))))
| ;; GPT forward (stateless function, returns vector of logits (Value)) | |
(defn gpt
  "One forward step of the tiny GPT for a single token.

   Args: state-dict (parameter map from init-state), token-id and
   pos-id (ints), keys/values (per-layer vectors of cached k / v
   activation vectors). Returns the vector of vocab logits (Value
   nodes). The updated KV caches are attached as metadata
   {:keys ks :values vs} on the returned vector so callers that thread
   the cache across positions can recover them; callers that ignore
   metadata observe exactly the same logits vector as before (backward
   compatible).

   Fixes over the previous version:
   - attention scores are built from Value ops instead of raw doubles,
     so gradients now flow into attn_wq and attn_wk (previously the
     scores were detached from the graph and those matrices received
     zero gradient);
   - per-head outputs are collected with pure vector ops instead of
     side-effecting conj! on a transient (whose return value must be
     used per the transient contract)."
  [state-dict token-id pos-id keys values]
  (let [tok-emb (get (:wte state-dict) token-id)
        pos-emb (get (:wpe state-dict) pos-id)
        x0 (rmsnorm (mapv v-add tok-emb pos-emb))]
    (loop [li 0, x x0, ks keys, vs values]
      (if (>= li n-layer)
        (with-meta (linear x (get state-dict :lm_head))
          {:keys ks :values vs})
        (let [param (fn [fmt] (get state-dict (keyword (format fmt li))))
              x-res x
              xn (rmsnorm x)
              q (linear xn (param "layer%d.attn_wq"))
              k (linear xn (param "layer%d.attn_wk"))
              v (linear xn (param "layer%d.attn_wv"))
              ;; append this position's k/v to the layer's cache
              ks (update ks li conj k)
              vs (update vs li conj v)
              inv-scale (make-value (/ 1.0 (Math/sqrt head-dim)))
              attend-head
              (fn [h]
                (let [hs (* h head-dim)
                      he (+ hs head-dim)
                      q-h (subvec q hs he)
                      k-h (mapv #(subvec % hs he) (get ks li))
                      v-h (mapv #(subvec % hs he) (get vs li))
                      ;; differentiable scaled dot-product scores
                      scores (mapv (fn [ki]
                                     (v-mul (reduce v-add (make-value 0.0)
                                                    (map v-mul q-h ki))
                                            inv-scale))
                                   k-h)
                      weights (softmax scores)]
                  ;; weighted sum of cached values, one node per head dim
                  (mapv (fn [j]
                          (reduce (fn [acc t]
                                    (v-add acc (v-mul (nth weights t)
                                                      (nth (nth v-h t) j))))
                                  (make-value 0.0)
                                  (range (count v-h))))
                        (range head-dim))))
              x-attn (vec (mapcat attend-head (range n-head)))
              ;; attention output projection + residual
              x1 (mapv v-add (linear x-attn (param "layer%d.attn_wo")) x-res)
              ;; MLP block + residual
              h1 (mapv v-relu (linear (rmsnorm x1) (param "layer%d.mlp_fc1")))
              x2 (mapv v-add (linear h1 (param "layer%d.mlp_fc2")) x1)]
          (recur (inc li) x2 ks vs))))))
| ;; Weighted choice for sampling | |
(defn weighted-choice
  "Sample an index in [0, (count weights)) with probability
   proportional to the (non-negative) weights, using the shared
   deterministic RNG.

   Robustness fix: the scan is clamped to the last index, so degenerate
   inputs (a zero or NaN total, or floating-point rounding in the
   running sum) can never push the cursor past the end of the vector
   and throw IndexOutOfBoundsException; well-formed inputs behave
   exactly as before."
  [weights]
  (let [total (reduce + weights)
        r (* (rand-double) total)
        last-i (dec (count weights))]
    (loop [i 0, acc 0.0]
      (let [acc' (+ acc (nth weights i))]
        (if (or (> acc' r) (>= i last-i))
          i
          (recur (inc i) acc'))))))
| ;; Main entry: training and sampling | |
(defn -main
  "Train the character-level model on the names dataset, printing the
   loss every step, then sample 20 names from the trained model.

   Changes over the previous version:
   - per-position losses are accumulated in a plain vector instead of
     side-effecting conj! on a transient (whose return value must be
     used per the transient contract);
   - the KV caches are threaded from position to position when `gpt`
     exposes the updated caches via metadata on the logits; with a gpt
     that attaches no metadata this behaves exactly as before;
   - a sample that reaches block-size without emitting BOS is printed
     as-is instead of being silently dropped."
  [& _args]
  (let [docs (read-docs)]
    (println (format "num docs: %d" (count docs)))
    (let [{:keys [uchars BOS vocab-size]} (build-vocab docs)
          _ (println (format "vocab size: %d" vocab-size))
          state-dict (init-state vocab-size)
          ;; one flat vector of every trainable Value node
          params (vec (mapcat identity (flatten-params state-dict)))
          ;; Adam hyperparameters (same values as the Python original)
          learning-rate 0.01
          beta1 0.85
          beta2 0.99
          eps-adam 1e-8
          m (double-array (count params))
          v (double-array (count params))
          num-steps 100
          empty-cache (vec (repeat n-layer []))]
      (println (format "num params: %d" (count params)))
      (doseq [step (range num-steps)]
        ;; ---- forward: mean cross-entropy over one document ----
        (let [doc (nth docs (mod step (count docs)))
              ;; BOS-delimited token ids: [BOS c0 c1 ... BOS]
              tokens (conj (into [(int BOS)] (map #(int (.indexOf uchars %)) doc)) BOS)
              n (min block-size (dec (count tokens)))
              losses
              (loop [pos-id 0, ks empty-cache, vs empty-cache, acc []]
                (if (>= pos-id n)
                  acc
                  (let [token-id (nth tokens pos-id)
                        target-id (nth tokens (inc pos-id))
                        logits (gpt state-dict token-id pos-id ks vs)
                        cache (meta logits)
                        probs (softmax logits)
                        loss-t (v-neg (v-log (nth probs target-id)))]
                    (recur (inc pos-id)
                           (or (:keys cache) ks)
                           (or (:values cache) vs)
                           (conj acc loss-t)))))
              loss (v-div (reduce v-add (make-value 0.0) losses)
                          (make-value (double n)))]
          (backward loss)
          ;; ---- Adam update with linear learning-rate decay ----
          (let [lr-t (* learning-rate (- 1.0 (/ step num-steps)))]
            (dotimes [i (count params)]
              (let [p (nth params i)
                    g @(:grad p)]
                (aset-double m i (+ (* beta1 (aget m i)) (* (- 1 beta1) g)))
                (aset-double v i (+ (* beta2 (aget v i)) (* (- 1 beta2) (* g g))))
                (let [m-hat (/ (aget m i) (- 1.0 (Math/pow beta1 (inc step))))
                      v-hat (/ (aget v i) (- 1.0 (Math/pow beta2 (inc step))))
                      update (/ (* lr-t m-hat) (+ (Math/sqrt v-hat) eps-adam))]
                  (swap! (:data p) - update)
                  (reset! (:grad p) 0.0)))))
          (println (format "step %4d / %4d | loss %.4f" (inc step) num-steps (val-data loss)))))
      ;; ---- inference: autoregressive sampling ----
      (let [temperature 0.5]
        (println "\n--- inference (new, hallucinated names) ---")
        (dotimes [sample-idx 20]
          (let [sb (StringBuilder.)]
            (loop [pos-id 0, token-id BOS, ks empty-cache, vs empty-cache]
              (if (>= pos-id block-size)
                ;; context exhausted: emit the (possibly truncated) sample
                (println (format "sample %2d: %s" (inc sample-idx) (.toString sb)))
                (let [logits (gpt state-dict token-id pos-id ks vs)
                      cache (meta logits)
                      ;; temperature-scaled logits, re-wrapped as constants
                      scaled (mapv (fn [l] (make-value (/ (val-data l) temperature))) logits)
                      probs (softmax scaled)
                      next-id (weighted-choice (mapv val-data probs))]
                  (if (= next-id BOS)
                    (println (format "sample %2d: %s" (inc sample-idx) (.toString sb)))
                    (do
                      (.append sb (nth uchars next-id))
                      (recur (inc pos-id)
                             next-id
                             (or (:keys cache) ks)
                             (or (:values cache) vs)))))))))))))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment