Last active
December 24, 2025 16:44
-
-
Save Mrutyunjay01/f0688f727f95e7ac7deba6ff0fb39dd3 to your computer and use it in GitHub Desktop.
Neural Networks in JAX from scratch.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "69b9d38d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import jax\n", | |
| "from math import ceil\n", | |
| "from functools import partial\n", | |
| "from typing import NamedTuple, Any\n", | |
| "from jax import numpy as jnp\n", | |
| "from sklearn.datasets import fetch_openml\n", | |
| "from sklearn.model_selection import StratifiedKFold, train_test_split\n", | |
| "\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "%matplotlib inline" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "c94a8159", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "fashion_mnist = fetch_openml(\"Fashion-MNIST\", cache=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "eb5b23f1", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "X = jnp.array(fashion_mnist.data)\n", | |
| "y = jnp.array(fashion_mnist.target.astype(\"int\"))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "49821634", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
class DataLoader:
    """Serve a feature/target array pair as consecutive mini-batches.

    Batches are contiguous slices taken in storage order; nothing is shuffled.
    """

    def __init__(self, X: jnp.array, y: jnp.array, batch_size: int):
        """Validate and store the arrays and the batch size.

        Args:
            X: Features; the first axis indexes samples.
            y: Targets; must have as many entries along axis 0 as ``X``.
            batch_size: Number of samples per batch; must be positive.

        Raises:
            ValueError: If ``X`` and ``y`` disagree on the sample count or
                ``batch_size`` is not positive.
        """
        # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
        if X.shape[0] != y.shape[0]:
            raise ValueError(
                f"mismatching samples between X {X.shape[0]} and y {y.shape[0]}"
            )
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.X = X
        self.y = y
        self.len_data = X.shape[0]
        self.batch_size = batch_size

    def get_data(self, is_train=True, drop_last=True):
        """Yield the data one batch at a time.

        Args:
            is_train: When true yield ``(features, targets)`` tuples,
                otherwise yield the features alone.
            drop_last: When true (the default, and the original behaviour) a
                trailing batch shorter than ``batch_size`` is skipped so every
                yielded batch has the same leading dimension.
        """
        for start in range(0, self.len_data, self.batch_size):
            stop = start + self.batch_size
            if drop_last and stop > self.len_data:
                # Only the final slice can be short, so iteration is finished.
                break
            features = self.X[start:stop]
            if is_train:
                yield features, self.y[start:stop]
            else:
                yield features
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "1510ed45", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Build a loader that serves 8 samples per batch and open a generator over
# its (features, targets) pairs.
data_loader = DataLoader(X, y, batch_size=8)
train_data = data_loader.get_data()
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "34187ab9", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
def show_data(x: jnp.array, y: jnp.array):
    """Plot every sample of a batch on a 4-column grid, labelled by class.

    Args:
        x: Batch of flattened images; each row must reshape to 28x28.
        y: Class label for each row of ``x``.
    """
    n_cols = 4
    n_samples = x.shape[0]
    if n_samples == 0:
        # Nothing to draw; a grid with zero rows cannot be created.
        return
    n_rows = ceil(n_samples / n_cols)
    # squeeze=False keeps ``axes`` two-dimensional even when there is one row.
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False)

    # ``axes.flat`` walks the grid row by row, so the position in that walk is
    # the sample index (row * n_cols + col). The previous ``row * col + col``
    # repeated and skipped samples on every row after the first.
    for idx, ax in enumerate(axes.flat):
        if idx >= n_samples:
            # Hide the unused cells of a ragged last row.
            ax.axis("off")
            continue
        ax.imshow(x[idx].reshape(28, 28))
        ax.set_xlabel(f"class: {y[idx]}")
    plt.show()
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "e64786a3", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/png": "", | |
| "text/plain": [ | |
| "<Figure size 640x480 with 8 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
# Take the next batch from the generator and render it.
batch_iter = iter(train_data)
x_sample, y_sample = next(batch_iter)
show_data(x=x_sample, y=y_sample)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "2affd4be", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# how long does it take to get the complete dataset w.r.t. a batch size\n", | |
| "# import time\n", | |
| "\n", | |
| "# data_loader = DataLoader(X, y, batch_size=2048)\n", | |
| "# train_data = data_loader.get_data(is_train=True)\n", | |
| "# tic = time.time()\n", | |
| "# for x_train, y_train in train_data:\n", | |
| "# continue\n", | |
| "# print(f'{time.time() - tic:.2f}')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "091f25b7", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "------------------ Fold 0 ------------------\n", | |
| "Iteration 0 Accuracy: 0.78, Loss 105.10041809082031\n", | |
| "Iteration 1 Accuracy: 0.84, Loss 50.06229782104492\n", | |
| "Iteration 2 Accuracy: 0.86, Loss 41.21639633178711\n", | |
| "Iteration 3 Accuracy: 0.87, Loss 36.25566864013672\n", | |
| "Iteration 4 Accuracy: 0.87, Loss 33.13346862792969\n", | |
| "Iteration 5 Accuracy: 0.86, Loss 31.37713050842285\n", | |
| "Iteration 6 Accuracy: 0.87, Loss 29.916488647460938\n", | |
| "Iteration 7 Accuracy: 0.86, Loss 28.327608108520508\n", | |
| "Iteration 8 Accuracy: 0.87, Loss 26.640146255493164\n", | |
| "Iteration 9 Accuracy: 0.86, Loss 25.484012603759766\n", | |
| "Iteration 10 Accuracy: 0.87, Loss 25.4431209564209\n", | |
| "Iteration 11 Accuracy: 0.88, Loss 25.435298919677734\n", | |
| "Iteration 12 Accuracy: 0.89, Loss 24.539390563964844\n", | |
| "Iteration 13 Accuracy: 0.89, Loss 23.322032928466797\n", | |
| "Iteration 14 Accuracy: 0.89, Loss 21.963422775268555\n", | |
| "Fold 0 Accuracy: 0.89\n", | |
| "------------------ Fold 1 ------------------\n", | |
| "Iteration 0 Accuracy: 0.77, Loss 111.32299041748047\n", | |
| "Iteration 1 Accuracy: 0.83, Loss 50.82963943481445\n", | |
| "Iteration 2 Accuracy: 0.84, Loss 39.62657165527344\n", | |
| "Iteration 3 Accuracy: 0.86, Loss 35.53031921386719\n", | |
| "Iteration 4 Accuracy: 0.87, Loss 32.69033432006836\n", | |
| "Iteration 5 Accuracy: 0.87, Loss 32.02552032470703\n", | |
| "Iteration 6 Accuracy: 0.87, Loss 29.951984405517578\n", | |
| "Iteration 7 Accuracy: 0.88, Loss 28.434228897094727\n", | |
| "Iteration 8 Accuracy: 0.88, Loss 27.313901901245117\n", | |
| "Iteration 9 Accuracy: 0.88, Loss 26.59328842163086\n", | |
| "Iteration 10 Accuracy: 0.88, Loss 26.086761474609375\n", | |
| "Iteration 11 Accuracy: 0.88, Loss 25.072847366333008\n", | |
| "Iteration 12 Accuracy: 0.89, Loss 23.68248748779297\n", | |
| "Iteration 13 Accuracy: 0.88, Loss 22.690715789794922\n", | |
| "Iteration 14 Accuracy: 0.89, Loss 22.063501358032227\n", | |
| "Fold 1 Accuracy: 0.89\n", | |
| "------------------ Fold 2 ------------------\n", | |
| "Iteration 0 Accuracy: 0.75, Loss 107.47624969482422\n", | |
| "Iteration 1 Accuracy: 0.84, Loss 51.53094482421875\n", | |
| "Iteration 2 Accuracy: 0.85, Loss 41.4818115234375\n", | |
| "Iteration 3 Accuracy: 0.84, Loss 36.80481719970703\n", | |
| "Iteration 4 Accuracy: 0.86, Loss 34.52952194213867\n", | |
| "Iteration 5 Accuracy: 0.87, Loss 31.991575241088867\n", | |
| "Iteration 6 Accuracy: 0.87, Loss 29.796144485473633\n", | |
| "Iteration 7 Accuracy: 0.87, Loss 28.153940200805664\n", | |
| "Iteration 8 Accuracy: 0.88, Loss 26.822463989257812\n", | |
| "Iteration 9 Accuracy: 0.88, Loss 26.207122802734375\n", | |
| "Iteration 10 Accuracy: 0.88, Loss 26.46046257019043\n", | |
| "Iteration 11 Accuracy: 0.88, Loss 24.957033157348633\n", | |
| "Iteration 12 Accuracy: 0.86, Loss 24.07745361328125\n", | |
| "Iteration 13 Accuracy: 0.87, Loss 23.486385345458984\n", | |
| "Iteration 14 Accuracy: 0.88, Loss 22.554162979125977\n", | |
| "Fold 2 Accuracy: 0.88\n", | |
| "------------------ Fold 3 ------------------\n", | |
| "Iteration 0 Accuracy: 0.76, Loss 110.51836395263672\n", | |
| "Iteration 1 Accuracy: 0.82, Loss 50.895751953125\n", | |
| "Iteration 2 Accuracy: 0.85, Loss 40.91993713378906\n", | |
| "Iteration 3 Accuracy: 0.87, Loss 36.93730926513672\n", | |
| "Iteration 4 Accuracy: 0.87, Loss 34.084014892578125\n", | |
| "Iteration 5 Accuracy: 0.85, Loss 32.147178649902344\n", | |
| "Iteration 6 Accuracy: 0.87, Loss 30.568044662475586\n", | |
| "Iteration 7 Accuracy: 0.87, Loss 30.234619140625\n", | |
| "Iteration 8 Accuracy: 0.87, Loss 28.12507438659668\n", | |
| "Iteration 9 Accuracy: 0.87, Loss 27.100788116455078\n", | |
| "Iteration 10 Accuracy: 0.88, Loss 26.548999786376953\n", | |
| "Iteration 11 Accuracy: 0.87, Loss 25.430347442626953\n", | |
| "Iteration 12 Accuracy: 0.87, Loss 25.04415512084961\n", | |
| "Iteration 13 Accuracy: 0.87, Loss 23.559974670410156\n", | |
| "Iteration 14 Accuracy: 0.87, Loss 23.09308433532715\n", | |
| "Fold 3 Accuracy: 0.87\n", | |
| "------------------ Fold 4 ------------------\n", | |
| "Iteration 0 Accuracy: 0.78, Loss 110.41956329345703\n", | |
| "Iteration 1 Accuracy: 0.82, Loss 48.65843200683594\n", | |
| "Iteration 2 Accuracy: 0.83, Loss 40.263362884521484\n", | |
| "Iteration 3 Accuracy: 0.86, Loss 36.066993713378906\n", | |
| "Iteration 4 Accuracy: 0.86, Loss 32.979530334472656\n", | |
| "Iteration 5 Accuracy: 0.86, Loss 30.87268829345703\n", | |
| "Iteration 6 Accuracy: 0.87, Loss 29.401487350463867\n", | |
| "Iteration 7 Accuracy: 0.87, Loss 28.30144500732422\n", | |
| "Iteration 8 Accuracy: 0.87, Loss 26.695728302001953\n", | |
| "Iteration 9 Accuracy: 0.87, Loss 25.73813819885254\n", | |
| "Iteration 10 Accuracy: 0.88, Loss 25.964645385742188\n", | |
| "Iteration 11 Accuracy: 0.88, Loss 25.051902770996094\n", | |
| "Iteration 12 Accuracy: 0.87, Loss 23.67770004272461\n", | |
| "Iteration 13 Accuracy: 0.89, Loss 23.780059814453125\n", | |
| "Iteration 14 Accuracy: 0.89, Loss 22.97209358215332\n", | |
| "Fold 4 Accuracy: 0.89\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
class Softmax:
    """Softmax over the last axis, shifted by the row maximum for stability."""

    def __init__(self):
        pass

    def __call__(self, x):
        """Return ``exp(x) / sum(exp(x))`` computed along the last axis."""
        # Subtracting the per-row maximum leaves the result unchanged but
        # keeps ``exp`` from overflowing on large logits.
        peak = jnp.max(x, axis=-1, keepdims=True)
        exps = jnp.exp(x - peak)
        total = jnp.sum(exps, axis=-1, keepdims=True)
        return exps / total
| "\n", | |
class RELU:
    """Rectified linear unit activation."""

    def __init__(self):
        pass

    def __call__(self, x: jnp.array):
        """Return ``x`` with every negative entry replaced by zero."""
        return jnp.maximum(0, x)
| "\n", | |
class LayerNorm:
    """Normalise the last axis to zero mean and unit variance.

    There are no learned scale or shift parameters.
    """

    def __init__(self):
        pass

    def __call__(self, x, eps=1e-9):
        """Return ``(x - mean) / sqrt(var + eps)`` along the last axis.

        Args:
            x: Input array; statistics are taken over its last axis.
            eps: Small constant added to the variance before the square root.
        """
        mu = jnp.mean(x, axis=-1, keepdims=True)
        sigma_sq = jnp.var(x, axis=-1, keepdims=True)
        centered = x - mu
        return centered / jnp.sqrt(sigma_sq + eps)
| "\n", | |
class LinearLayer:
    """Affine layer ``x @ w + b`` whose parameters live outside the object."""

    def __init__(self, in_features: int, out_features: int):
        """Record the input and output widths; no parameters are created here."""
        self.in_features = in_features
        self.out_features = out_features

    def init(self, prng_key):
        """Create a fresh parameter dict for this layer.

        Args:
            prng_key: JAX PRNG key used to draw the weights.

        Returns:
            ``{"w": (in_features, out_features) array, "b": zeros(out_features)}``.
        """
        weight_key = jax.random.split(prng_key, 2)[0]
        # Fan-in scaling sqrt(2 / in_features), intended to prevent dying neurons.
        scale = jnp.sqrt(2/self.in_features)
        weight_shape = (self.in_features, self.out_features)
        weights = scale * jax.random.normal(weight_key, shape=weight_shape)
        bias = jnp.zeros((self.out_features, ))
        return {"w": weights, "b": bias}

    def __call__(self, params: dict, x: jnp.array):
        """Apply the layer to ``x`` using the weights and bias in ``params``."""
        return x @ params["w"] + params["b"]
| "\n", | |
# define a classifier model
class FashionMNISTClassifier:
    """Five-layer MLP (in -> 784 -> 512 -> 128 -> 64 -> 10) returning logits.

    The first three linear layers are followed by LayerNorm and ReLU, the
    fourth by ReLU only, and the last one emits the raw class scores.
    Parameters are kept in an external dict keyed by layer name.
    """

    _LAYER_NAMES = ("linear_0", "linear_1", "linear_2", "linear_3", "linear_4")

    def __init__(self, in_features: int):
        """Build the layer objects; parameters are created later by ``init``."""
        self.linear_0 = LinearLayer(in_features=in_features, out_features=784)
        self.linear_1 = LinearLayer(in_features=784, out_features=512)
        self.linear_2 = LinearLayer(in_features=512, out_features=128)
        self.linear_3 = LinearLayer(in_features=128, out_features=64)
        self.linear_4 = LinearLayer(in_features=64, out_features=10)
        self.softmax = Softmax()
        self.relu = RELU()
        self.layer_norm = LayerNorm()

    def init(self, prng_key):
        """Return the parameter tree, one sub-dict per layer, each with its own key."""
        keys = jax.random.split(prng_key, 5)
        return {
            name: getattr(self, name).init(prng_key=keys[position])
            for position, name in enumerate(self._LAYER_NAMES)
        }

    def forward(self, params: dict, x: jnp.array):
        """Map a batch of inputs to unnormalised class scores (logits)."""
        # Normalised hidden blocks: linear -> layer norm -> relu.
        for name in ("linear_0", "linear_1", "linear_2"):
            layer = getattr(self, name)
            x = layer(params=params[name], x=x)
            x = self.layer_norm(x)
            x = self.relu(x)
        # The last hidden block has no normalisation.
        x = self.relu(self.linear_3(params=params["linear_3"], x=x))
        return self.linear_4(params=params["linear_4"], x=x)

    def validate(self, params, x, y):
        """Return a boolean array marking which predictions equal ``y``."""
        probabilities = self.softmax(self.forward(params=params, x=x))
        return jnp.argmax(probabilities, axis=-1) == y

    def predict(self, params, x, proba=True):
        """Return class probabilities, or the raw logits when ``proba`` is false."""
        logits = self.forward(params, x)
        if proba:
            return self.softmax(logits)
        return logits
| "\n", | |
class AdamState(NamedTuple):
    """Optimizer state carried between Adam steps."""

    m: Any  # running average of the gradients, one entry per parameter leaf
    v: Any  # running average of the squared gradients, same structure as m
    t: int  # number of updates applied so far


class AdamOptimizer:
    """Adam with bias-corrected moment estimates over arbitrary pytrees."""

    def __init__(self, learning_rate, beta_1, beta_2, eps):
        """Store the step size, the two decay rates and the denominator epsilon."""
        self.learning_rate = learning_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = eps

    def init(self, params) -> AdamState:
        """Return a zeroed state whose moment trees mirror ``params``."""
        zeros_like_params = lambda: jax.tree.map(jnp.zeros_like, params)
        return AdamState(m=zeros_like_params(), v=zeros_like_params(), t=0)

    def update(self, grad, state: AdamState, params: dict=None):
        """Turn raw gradients into parameter deltas and advance the state.

        Args:
            grad: Gradient tree with the same structure as the state's moments.
            state: State returned by ``init`` or by the previous ``update``.
            params: Unused; accepted so optimizers share one call signature.

        Returns:
            ``(deltas, new_state)``. The deltas already carry the negative
            learning rate, so they are meant to be added to the parameters.
        """
        beta_1, beta_2 = self.beta_1, self.beta_2

        def blend_first(m_val, g_val):
            return beta_1 * m_val + (1 - beta_1) * g_val

        def blend_second(v_val, g_val):
            return beta_2 * v_val + (1 - beta_2) * (g_val ** 2)

        new_m = jax.tree.map(blend_first, state.m, grad)
        new_v = jax.tree.map(blend_second, state.v, grad)

        # Bias correction compensates for the moments starting at zero.
        step = state.t + 1
        bias_correction_1 = 1 - beta_1 ** step
        bias_correction_2 = 1 - beta_2 ** step

        def delta(m_val, v_val):
            m_hat = m_val / bias_correction_1
            v_hat = v_val / bias_correction_2
            return - self.learning_rate * m_hat / (jnp.sqrt(v_hat) + self.epsilon)

        deltas = jax.tree.map(delta, new_m, new_v)
        return deltas, AdamState(m=new_m, v=new_v, t=step)
| "\n", | |
class SGDOptimizer:
    """Plain stochastic gradient descent: delta = -learning_rate * grad."""

    def __init__(self, learning_rate):
        """Store the step size."""
        self.learning_rate = learning_rate

    def init(self, params=None):
        """Return the (empty) optimizer state.

        Args:
            params: Ignored. Accepted so the call ``optim.init(params=params)``
                works for this optimizer exactly as it does for
                ``AdamOptimizer``; previously that call raised ``TypeError``.

        Returns:
            ``None`` because SGD keeps no state between steps.
        """
        return None

    def update(self, grad, state, params: dict=None):
        """Scale the gradients into parameter deltas.

        Args:
            grad: Gradient tree.
            state: Passed through unchanged.
            params: Unused; kept for signature parity with ``AdamOptimizer``.

        Returns:
            ``(deltas, state)`` where every delta leaf is
            ``-learning_rate * gradient``.
        """
        deltas = jax.tree.map(lambda g: - self.learning_rate * g, grad)
        return deltas, state
| "\n", | |
def weight_decay_l2(params):
    """Return the sum of squared entries over every leaf of ``params``."""
    squared_norms = jax.tree.map(lambda leaf: jnp.sum(leaf**2), params)
    return jax.tree.reduce(lambda acc, term: acc + term, squared_norms)
| "\n", | |
def loss_fn(params: dict, model: FashionMNISTClassifier, x: jnp.array, y_true: jnp.array):
    """Mean cross-entropy between the model's logits and ``y_true``.

    Computes ``-mean_over_batch(sum_over_classes(y_true * log_softmax(logits)))``.
    ``y_true`` is multiplied element-wise with the log-probabilities, so it has
    to be broadcastable against the logits (the training step passes one-hot rows).
    """
    logits = model.forward(params=params, x=x)
    assert y_true.shape[0] == logits.shape[0], f"mismatching shape between ground truth {y_true.shape} and logits {logits.shape}"
    # log-softmax via logsumexp avoids taking the log of tiny probabilities.
    log_normaliser = jax.nn.logsumexp(logits, axis=-1, keepdims=True)
    per_sample = jnp.sum(y_true * (logits - log_normaliser), axis=-1)
    return -jnp.mean(per_sample)
| "\n", | |
def apply_updates(grad, params, weight_decay=1e-5):
    """Add the update tree to the parameters and shrink them by weight decay.

    Args:
        grad: Update tree with the same structure as ``params``; each leaf is
            added to the matching parameter.
        params: Current parameter tree.
        weight_decay: Fraction of each parameter subtracted on every call
            (decoupled weight decay, applied to every leaf).

    Returns:
        A new tree with ``p + g - weight_decay * p`` for each leaf pair.
    """
    def step(p, g):
        return p + g - weight_decay * p

    return jax.tree.map(step, params, grad)
| "\n", | |
# ``model`` (arg 0) and ``optim`` (arg 2) are static: they are plain Python
# objects rather than array trees, so jit specialises the compiled step on them.
@partial(jax.jit, static_argnums=[0, 2])
def training_step(model, params, optim, optim_state, x_sample, y_sample):
    """Run one optimisation step on a single batch.

    Returns:
        ``(new_params, new_optim_state, loss)`` where ``loss`` is the batch
        loss evaluated before the update.
    """
    one_hot_targets = jax.nn.one_hot(y_sample, num_classes=10)
    value_and_grad_fn = jax.value_and_grad(loss_fn)  # differentiates w.r.t. params
    loss, grad = value_and_grad_fn(params, model, x_sample, one_hot_targets)
    deltas, optim_state = optim.update(grad=grad, state=optim_state)
    new_params = apply_updates(deltas, params)
    return new_params, optim_state, loss
| "\n", | |
features = jnp.array(fashion_mnist.data)
targets = jnp.array(fashion_mnist.target.astype("int"))
# Scale BEFORE splitting so the training, validation and held-out test data all
# receive the same preprocessing (previously only the training split was scaled).
# Experiment idea: remove the normalization and see what happens if the weights
# are initialized from a normal distribution with 0 mean and 1 var.
features = features/255.
X, X_test, y, y_test = train_test_split(features, targets, test_size=0.1, shuffle=True, random_state=42)

prng_key = jax.random.key(seed=42)

skf = StratifiedKFold(5, shuffle=True, random_state=42)
folds = skf.split(X, y)
for fold, (train_idx, val_idx) in enumerate(folds):
    print(f"------------------ Fold {fold} ------------------")

    X_train, y_train = X[train_idx], y[train_idx]
    X_val, y_val = X[val_idx], y[val_idx]

    train_data_loader = DataLoader(X_train, y_train, 512)

    # Take the input width from this fold's data instead of a variable left
    # over from an earlier notebook cell.
    model = FashionMNISTClassifier(in_features=X_train.shape[1])
    # The same key is used for every fold, so all folds start from identical weights.
    params = model.init(prng_key=prng_key)
    # optim = SGDOptimizer(learning_rate=0.01)
    optim = AdamOptimizer(learning_rate=0.01, beta_1=0.9, beta_2=0.999, eps=1e-9)
    optim_state = optim.init(params=params)
    n_iterations = 15

    for iteration in range(n_iterations):
        train_data = train_data_loader.get_data(is_train=True)
        # Sum of the per-batch mean losses over one pass through the data.
        iteration_loss = 0
        for x_batch, y_batch in train_data:
            params, optim_state, loss = training_step(model, params, optim, optim_state, x_batch, y_batch)
            iteration_loss += loss

        result = model.validate(params, X_val, y_val)
        print(f"Iteration {iteration} Accuracy: {result.sum()/result.shape[0]:.2f}, Loss {iteration_loss}")

    # Report from ``fold_result``; the old print read the stale ``result``.
    fold_result = model.validate(params, X_val, y_val)
    print(f"Fold {fold} Accuracy: {fold_result.sum()/fold_result.shape[0]:.2f}")
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "3010bd64", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "11.0 {'x': Array(3., dtype=float32, weak_type=True), 'y': Array(2., dtype=float32, weak_type=True)}\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
def f(params: dict):
    """Toy scalar function ``x**3 + y**2 + 9`` used to try out autodiff."""
    x_val = params['x']
    y_val = params['y']
    return x_val**3 + y_val**2 + 9


# value_and_grad returns a pair: [0] is f evaluated at the input, [1] is the
# gradient, laid out as a dict with the same keys as the input.
value_and_grad_f = jax.value_and_grad(f)
loss, grad = value_and_grad_f({'x': 1., 'y': 1.})
print(loss, grad)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "41f8f065", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "6" | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
sample_data = {
    "x": {"a": 1, "b": 2},
    "y": {"c": 1, "d": 2},
}

# tree.reduce folds the two-argument function over all leaves of the nested
# dict, carrying the running result along (1 + 2 + 1 + 2 = 6).
jax.tree.reduce(lambda acc, leaf: acc + leaf, tree=sample_data)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "id": "a54a95ff", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Accuracy: 0.87\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
# Score the model and parameters left over from the last trained fold on the
# held-out test split.
# NOTE(review): X_test needs the same scaling the training data received;
# confirm that where X_test is created.
probs = model.predict(params, X_test)
y_preds = jnp.argmax(probs, axis=-1)
result = (y_preds == y_test)
n_correct, n_total = result.sum(), result.shape[0]
print(f"Accuracy: {n_correct/n_total:.2f}")
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "2e050283", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "cml", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.11.10" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment