(ns microgpt.core
  (:refer-clojure :exclude [rand-int])
  (:require [clojure.java.io :as io]
            [clojure.string :as str])
  (:gen-class))

;; A minimal, dependency-free translation of the provided Python `microgpt.py` into
;; Clojure. It keeps the same algorithmic structure: a tiny autograd Value type, a
;; GPT-like forward function, Adam updates, and a simple training loop. It is intended
;; for educational purposes and will be slow for real training.

;; Utilities and a deterministic RNG so runs are reproducible
(def ^java.util.Random rng (java.util.Random. 42))
(defn rand-gauss [] (.nextGaussian rng))
(defn rand-double [] (.nextDouble rng))
(defn rand-int [n] (.nextInt rng (int n)))
;; Value representation: a map per graph node, with mutable atoms for :data and
;; :grad so parameters can be updated in place
(defn make-value
  ([x] (make-value x [] []))
  ([x children local-grads]
   {:id (java.util.UUID/randomUUID)
    :data (atom (double x))
    :grad (atom 0.0)
    :children (vec children)
    :local-grads (vec local-grads)}))

(defn ensure-value [x]
  (cond
    (map? x) x
    (number? x) (make-value x)
    :else (throw (ex-info "Unsupported value for ensure-value" {:val x}))))
;; Basic ops that create new Value nodes and store local grads (plain numbers
;; computed at forward time)
(defn val-data [v] (double @(:data v)))

(defn v-add [a b]
  (let [a (ensure-value a) b (ensure-value b)
        d (+ (val-data a) (val-data b))]
    (make-value d [a b] [1.0 1.0])))

(defn v-mul [a b]
  (let [a (ensure-value a) b (ensure-value b)
        da (val-data a) db (val-data b)
        d (* da db)]
    (make-value d [a b] [db da])))

(defn v-neg [a]
  (let [a (ensure-value a)]
    (make-value (- (val-data a)) [a] [-1.0])))

(defn v-sub [a b]
  (v-add (ensure-value a) (v-neg (ensure-value b))))

(defn v-pow [a e]
  (let [a (ensure-value a) da (val-data a)
        d (Math/pow da e)]
    (make-value d [a] [(* e (Math/pow da (dec e)))])))

(defn v-log [a]
  (let [a (ensure-value a) da (val-data a)]
    (make-value (Math/log da) [a] [(/ 1.0 da)])))

(defn v-exp [a]
  (let [a (ensure-value a) da (val-data a) e (Math/exp da)]
    (make-value e [a] [e])))

(defn v-relu [a]
  (let [a (ensure-value a) da (val-data a)]
    (make-value (max 0.0 da) [a] [(if (> da 0.0) 1.0 0.0)])))

(defn v-div [a b]
  (v-mul (ensure-value a) (v-pow b -1.0)))
;; Backpropagation: build a topological ordering of the graph, then propagate
;; gradients from the root back to the leaves. (An atom is used for the topo
;; accumulator; transients only guarantee correct results through their
;; return values, so conj!-for-side-effects is avoided.)
(defn backward [root]
  (let [topo (atom [])
        visited (atom #{})]
    (letfn [(build-topo [v]
              (when-not (contains? @visited (:id v))
                (swap! visited conj (:id v))
                (doseq [c (:children v)] (build-topo c))
                (swap! topo conj v)))]
      (build-topo root)
      (reset! (:grad root) 1.0)
      (doseq [v (rseq @topo)]
        (let [g @(:grad v)]
          (dotimes [i (count (:children v))]
            (let [child (nth (:children v) i)
                  local (nth (:local-grads v) i)]
              (swap! (:grad child) + (* local g)))))))))
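;; A minimal REPL sketch of how the pieces above compose: for y = (x + 2) * x
;; at x = 3, the value is 15 and dy/dx = 2x + 2 = 8.
(comment
  (let [x (make-value 3.0)
        y (v-mul (v-add x 2.0) x)]
    (backward y)
    [(val-data y) @(:grad x)]))  ;; => [15.0 8.0]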
;; Helpers for creating matrices/vectors of params
(defn matrix [nout nin & {:keys [std] :or {std 0.08}}]
  (vec (for [_ (range nout)]
         (vec (for [_ (range nin)] (make-value (* (double std) (rand-gauss))))))))

;; Linear layer: x is a vector of Value, w is a matrix (rows of weights as Value)
(defn linear [x w]
  (mapv (fn [wo]
          (reduce (fn [acc [wi xi]] (v-add acc (v-mul wi xi)))
                  (make-value 0.0)
                  (map vector wo x)))
        w))
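;; A small REPL sketch: a 2x3 weight matrix maps a length-3 input to a length-2
;; output (weights fixed by hand here instead of `matrix`, so the result is
;; predictable):
(comment
  (let [w [[(make-value 1.0) (make-value 0.0) (make-value 0.0)]
           [(make-value 0.0) (make-value 1.0) (make-value 1.0)]]
        x (mapv make-value [2.0 3.0 4.0])]
    (mapv val-data (linear x w))))  ;; => [2.0 7.0]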
;; softmax over a vector of Value. Subtracting the (detached) max is the usual
;; numerical-stability trick; softmax is shift-invariant, so treating the max as
;; a constant does not change the gradients.
(defn softmax [logits]
  (let [maxv (apply max (map val-data logits))
        exps (mapv (fn [v] (v-exp (v-sub v maxv))) logits)
        total (reduce v-add (make-value 0.0) exps)]
    (mapv (fn [e] (v-div e total)) exps)))
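;; Sanity check from the REPL: the outputs are positive and sum to ~1.0.
(comment
  (let [p (softmax (mapv make-value [1.0 2.0 3.0]))]
    (reduce + (map val-data p))))  ;; => ~1.0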
;; rmsnorm: scale the vector by (mean(x^2) + eps)^-0.5. The scale is built out
;; of Value ops (rather than a detached constant) so gradients flow through the
;; normalization.
(defn rmsnorm [x]
  (let [n (count x)
        ms (v-div (reduce v-add (make-value 0.0) (mapv (fn [xi] (v-mul xi xi)) x))
                  (double n))
        scale (v-pow (v-add ms 1e-5) -0.5)]
    (mapv (fn [xi] (v-mul xi scale)) x)))
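;; After rmsnorm the root-mean-square of the output is ~1 (up to the 1e-5 eps):
(comment
  (let [y (mapv val-data (rmsnorm (mapv make-value [3.0 4.0])))]
    (Math/sqrt (/ (reduce + (map #(* % %) y)) 2))))  ;; => ~1.0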
;; Model hyperparameters (same defaults as the Python version)
(def n-embd 16)
(def n-head 4)
(def n-layer 1)
(def block-size 16)
(def head-dim (quot n-embd n-head))

;; Load the dataset, fetching it from the makemore repo if it is missing locally
(defn ensure-input-file []
  (let [f (java.io.File. "input.txt")]
    (when-not (.exists f)
      (let [url (java.net.URL. "https://raw.githubusercontent.com/karpathy/makemore/refs/heads/master/names.txt")]
        (with-open [in (.openStream url)
                    out (java.io.FileOutputStream. f)]
          (io/copy in out))))))

(defn read-docs []
  (ensure-input-file)
  (->> (slurp "input.txt")
       str/split-lines
       (map str/trim)
       (filter seq)
       vec))

(defn build-vocab [docs]
  (let [uchars (->> docs (apply str) set sort vec)
        bos (count uchars)
        vocab-size (inc (count uchars))]
    {:uchars uchars :BOS bos :vocab-size vocab-size}))
;; Build the initial state dict of parameter matrices
(defn init-state [vocab-size]
  (let [sd (atom {})]
    (swap! sd assoc :wte (matrix vocab-size n-embd))
    (swap! sd assoc :wpe (matrix block-size n-embd))
    (swap! sd assoc :lm_head (matrix vocab-size n-embd))
    (dotimes [i n-layer]
      (swap! sd assoc (keyword (format "layer%d.attn_wq" i)) (matrix n-embd n-embd))
      (swap! sd assoc (keyword (format "layer%d.attn_wk" i)) (matrix n-embd n-embd))
      (swap! sd assoc (keyword (format "layer%d.attn_wv" i)) (matrix n-embd n-embd))
      (swap! sd assoc (keyword (format "layer%d.attn_wo" i)) (matrix n-embd n-embd))
      (swap! sd assoc (keyword (format "layer%d.mlp_fc1" i)) (matrix (* 4 n-embd) n-embd))
      (swap! sd assoc (keyword (format "layer%d.mlp_fc2" i)) (matrix n-embd (* 4 n-embd))))
    @sd))

;; Flatten every parameter Value in the state dict into a single vector (the
;; Adam buffers below index into this)
(defn flatten-params [state-dict]
  (vec (for [[_ mat] state-dict
             row mat
             p row]
         p)))
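;; Rough parameter budget at the defaults above, for vocab-size v:
;;   wte v*16 + wpe 16*16 + lm_head v*16, plus 3072 per layer (4 attn mats + fc1 + fc2).
;; For names.txt the vocab is likely 27 (26 lowercase letters + BOS), giving 4192:
(comment
  (count (flatten-params (init-state 27))))  ;; => 4192 under that assumption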
;; GPT forward pass for one token, returning the logits plus the updated
;; per-layer key/value caches. The caches are immutable vectors here, so (unlike
;; the Python version, which mutates lists in place) they must be threaded back
;; to the caller.
(defn gpt [state-dict token-id pos-id kv-keys kv-values]
  (let [tok-emb (get (:wte state-dict) token-id)
        pos-emb (get (:wpe state-dict) pos-id)
        x (mapv v-add tok-emb pos-emb)
        x (rmsnorm x)]
    (loop [li 0 x x kv-keys kv-keys kv-values kv-values]
      (if (>= li n-layer)
        {:logits (linear x (:lm_head state-dict))
         :kv-keys kv-keys
         :kv-values kv-values}
        (let [x-res x
              x (rmsnorm x)
              wq (get state-dict (keyword (format "layer%d.attn_wq" li)))
              wk (get state-dict (keyword (format "layer%d.attn_wk" li)))
              wv (get state-dict (keyword (format "layer%d.attn_wv" li)))
              wo (get state-dict (keyword (format "layer%d.attn_wo" li)))
              q (linear x wq)
              k (linear x wk)
              v (linear x wv)
              kv-keys (update kv-keys li conj k)
              kv-values (update kv-values li conj v)
              ;; Multi-head attention over the cached keys/values. All of this is
              ;; Value arithmetic, so gradients flow back into wq/wk/wv.
              x-attn (vec (mapcat
                           (fn [h]
                             (let [hs (* h head-dim)
                                   q-h (subvec q hs (+ hs head-dim))
                                   k-h (mapv #(subvec % hs (+ hs head-dim)) (get kv-keys li))
                                   v-h (mapv #(subvec % hs (+ hs head-dim)) (get kv-values li))
                                   attn-logits (mapv (fn [ki]
                                                       (v-div (reduce v-add (make-value 0.0)
                                                                      (map v-mul q-h ki))
                                                              (Math/sqrt head-dim)))
                                                     k-h)
                                   attn-weights (softmax attn-logits)]
                               (for [j (range head-dim)]
                                 (reduce (fn [acc t]
                                           (v-add acc (v-mul (nth attn-weights t)
                                                             (nth (nth v-h t) j))))
                                         (make-value 0.0)
                                         (range (count v-h))))))
                           (range n-head)))
              x (linear x-attn wo)
              x (mapv v-add x x-res)
              ;; MLP block with residual connection
              x-res x
              x (rmsnorm x)
              fc1 (mapv v-relu (linear x (get state-dict (keyword (format "layer%d.mlp_fc1" li)))))
              fc2 (linear fc1 (get state-dict (keyword (format "layer%d.mlp_fc2" li))))
              x (mapv v-add fc2 x-res)]
          (recur (inc li) x kv-keys kv-values))))))
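;; A REPL sketch of how a caller threads the caches through successive calls
;; (a toy vocab size of 27 is assumed here; the training loop below does the
;; same thing for real):
(comment
  (let [sd (init-state 27)
        empty-cache (vec (repeat n-layer []))
        {:keys [logits kv-keys]} (gpt sd 0 0 empty-cache empty-cache)]
    [(count logits) (count (first kv-keys))]))  ;; => [27 1]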
;; Weighted (multinomial) choice for sampling: returns index i with probability
;; weights[i] / total. The last-index guard protects against floating-point
;; rounding walking past the end of the vector.
(defn weighted-choice [weights]
  (let [total (reduce + weights)
        r (* (rand-double) total)]
    (loop [i 0 acc 0.0]
      (let [acc (+ acc (nth weights i))]
        (if (or (> acc r) (= i (dec (count weights))))
          i
          (recur (inc i) acc))))))
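;; With weights [1 1 2], index 2 should come up roughly half the time:
(comment
  (frequencies (repeatedly 1000 #(weighted-choice [1.0 1.0 2.0]))))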
;; Main entry: training, then sampling
(defn -main [& _args]
  (let [docs (read-docs)]
    (println (format "num docs: %d" (count docs)))
    (let [{:keys [uchars BOS vocab-size]} (build-vocab docs)
          _ (println (format "vocab size: %d" vocab-size))
          state-dict (init-state vocab-size)
          params (flatten-params state-dict)
          ;; Adam hyperparameters and moment buffers
          learning-rate 0.01
          beta1 0.85
          beta2 0.99
          eps-adam 1e-8
          m (double-array (count params))
          v (double-array (count params))
          num-steps 100]
      (println (format "num params: %d" (count params)))
      (doseq [step (range num-steps)]
        ;; Pick the next doc and tokenize it as BOS + chars + BOS
        (let [doc (nth docs (mod step (count docs)))
              tokens (-> (into [BOS] (map #(.indexOf uchars %) doc))
                         (conj BOS))
              n (min block-size (dec (count tokens)))
              ;; Forward pass over the sequence, threading the KV caches and
              ;; collecting the per-position cross-entropy losses
              losses (loop [pos-id 0
                            kv-keys (vec (repeat n-layer []))
                            kv-values (vec (repeat n-layer []))
                            losses []]
                       (if (>= pos-id n)
                         losses
                         (let [token-id (nth tokens pos-id)
                               target-id (nth tokens (inc pos-id))
                               {:keys [logits kv-keys kv-values]} (gpt state-dict token-id pos-id kv-keys kv-values)
                               probs (softmax logits)
                               loss-t (v-neg (v-log (nth probs target-id)))]
                           (recur (inc pos-id) kv-keys kv-values (conj losses loss-t)))))
              loss (v-div (reduce v-add (make-value 0.0) losses) (double n))]
          (backward loss)
          ;; Adam update with a linearly decayed learning rate
          (let [lr-t (* learning-rate (- 1.0 (/ (double step) num-steps)))]
            (dotimes [i (count params)]
              (let [p (nth params i)
                    g @(:grad p)]
                (aset-double m i (+ (* beta1 (aget m i)) (* (- 1.0 beta1) g)))
                (aset-double v i (+ (* beta2 (aget v i)) (* (- 1.0 beta2) (* g g))))
                (let [m-hat (/ (aget m i) (- 1.0 (Math/pow beta1 (inc step))))
                      v-hat (/ (aget v i) (- 1.0 (Math/pow beta2 (inc step))))
                      update (/ (* lr-t m-hat) (+ (Math/sqrt v-hat) eps-adam))]
                  (swap! (:data p) - update)
                  (reset! (:grad p) 0.0)))))
          (println (format "step %4d / %4d | loss %.4f" (inc step) num-steps (val-data loss)))))
      ;; Inference: sample until BOS is produced or block-size is reached
      (let [temperature 0.5]
        (println "\n--- inference (new, hallucinated names) ---")
        (dotimes [sample-idx 20]
          (let [sb (StringBuilder.)]
            (loop [pos-id 0
                   token-id BOS
                   kv-keys (vec (repeat n-layer []))
                   kv-values (vec (repeat n-layer []))]
              (if (>= pos-id block-size)
                (println (format "sample %2d: %s" (inc sample-idx) (.toString sb)))
                (let [{:keys [logits kv-keys kv-values]} (gpt state-dict token-id pos-id kv-keys kv-values)
                      scaled (mapv #(make-value (/ (val-data %) temperature)) logits)
                      probs (softmax scaled)
                      next-id (weighted-choice (mapv val-data probs))]
                  (if (= next-id BOS)
                    (println (format "sample %2d: %s" (inc sample-idx) (.toString sb)))
                    (do (.append sb (nth uchars next-id))
                        (recur (inc pos-id) next-id kv-keys kv-values))))))))))))