@iankronquist
Created December 15, 2025 00:04
igptv4 llm training 2025-12-14
'''
We don't have enough disk space to unpack tokenized copies of the FineWeb dataset, so we tokenize on the fly as batches are built.
'''
import os
import time
import tiktoken
import random
from enum import Enum
from tqdm import tqdm
import torch
import typing
import mlx.core
from datasets import load_dataset
from simple_data_loader import SimpleDataLoaderBackend
class FineWebDataLoader:
def __init__(self, B: int, T: int, file_path: str, backend: int, device='mps', subset='train', start_column=None, feature='text', shuffle=False):
self.durations = []
self.shuffle = shuffle
self.B = B
self.T = T
self.device = device
self.backend = backend
if isinstance(file_path, list):
parquets = file_path
else:
parquets = [os.path.join(file_path, fname) for fname in os.listdir(file_path) if fname.endswith('parquet')]
self.dataset: typing.List = load_dataset('parquet', data_files=parquets)[subset][feature]
self.tokens = None
self.tokens_pos = 0
if self.shuffle:
            self.shuffled_indices = list(range(len(self.dataset)))
random.shuffle(self.shuffled_indices)
self.encoder = tiktoken.get_encoding('gpt2')
self.eot = self.encoder._special_tokens['<|endoftext|>']
self.epoch = 0
self.current_column = 0
self.start_column = start_column
if start_column is not None:
self.current_column = start_column
self.current_position = 0
if self.backend == SimpleDataLoaderBackend.Torch:
self.dtype = torch.int32
self.batch = torch.empty((B * T + 1), dtype=self.dtype)
else:
self.dtype = mlx.core.uint16
self.batch = mlx.core.zeros((B * T + 1), dtype=self.dtype)
print('Batch size', B * T)
#total_tokens = 10e9
print(f"1 epoch = {self.batches_per_epoch()} batches")
#self.progress_bar = tqdm(total=len(self.dataset), unit='cols', desc='cols ')
self.progress_bar = None
self.encoded = None
self.next_encoded = None
#def __del__(self):
# self.progress_bar.close()
def batches_per_epoch(self):
total_tokens = 10e9
return total_tokens / (self.B * self.T)
def percent(self):
return self.current_column / len(self.dataset) * 100
def get_column(self, index):
if self.shuffle:
index = self.shuffled_indices[index]
return self.dataset[index]
def encode(self):
if self.next_encoded is None:
t0 = time.time()
tokens_list = [self.eot] + self.encoder.encode(self.get_column(self.current_column), allowed_special={'<|endoftext|>',})
if self.backend == SimpleDataLoaderBackend.Torch:
self.next_encoded = torch.tensor(tokens_list, dtype=self.dtype)
else:
self.next_encoded = mlx.core.array(tokens_list, dtype=self.dtype)
t1 = time.time()
self.durations.append(t1 - t0)
    def get_encoded(self):
        # Consume the column pre-tokenized by encode(). A single batch can span several
        # short columns, so fall back to encoding synchronously when the buffer is empty;
        # otherwise the previously returned tokens would be reused for every column.
        if self.next_encoded is None:
            self.encode()
        assert self.next_encoded is not None
        self.encoded = self.next_encoded
        self.next_encoded = None
        return self.encoded
    def reset(self):
        self.current_position = 0
        # Return to the configured starting column (defaults to 0).
        if self.start_column is None:
            self.start_column = 0
        self.current_column = self.start_column
def view(self, tensor, B, T):
#match self.backend:
if self.backend == SimpleDataLoaderBackend.Torch:
return tensor.view(B, T).to(self.device)
else: # SimpleDataLoaderBackend.Mlx:
return tensor.reshape((B, T))
def next_column(self):
#print('column', self.current_column)
if self.current_column >= len(self.dataset):
self.current_column = 0
self.epoch += 1
if self.progress_bar:
self.progress_bar.reset()
if self.progress_bar:
self.progress_bar.update(1)
#t0 = time.time()
#if self.backend == SimpleDataLoaderBackend.Torch:
# self.tokens = torch.tensor([self.eot] + self.encoder.encode(self.dataset[self.current_column], allowed_special={'<|endoftext|>',}), dtype=self.dtype)
#else:
# self.tokens = mlx.core.array([self.eot] + self.encoder.encode(self.dataset[self.current_column], allowed_special={'<|endoftext|>',}), dtype=self.dtype)
#t1 = time.time()
#self.durations.append(t1 - t0)
self.tokens = self.get_encoded()
self.tokens_pos = 0
self.current_column += 1
return self.tokens
def peek_batch(self):
B, T = self.B, self.T
x = self.view(self.batch[:-1], B, T)
y = self.view(self.batch[1:], B, T)
return x, y
def next_batch(self):
batch_pos = 0
while batch_pos < len(self.batch):
if self.tokens_pos == 0:
#self.batch[batch_pos] = self.eot
#batch_pos += 1
if batch_pos == len(self.batch):
break
self.next_column()
batch_remaining = len(self.batch) - batch_pos
assert self.tokens is not None
tokens_remaining = len(self.tokens) - self.tokens_pos
chunk_len = min(batch_remaining, tokens_remaining)
self.batch[batch_pos:batch_pos+chunk_len] = self.tokens[self.tokens_pos:self.tokens_pos+chunk_len]
batch_pos += chunk_len
self.tokens_pos += chunk_len
if self.tokens_pos >= len(self.tokens):
self.tokens_pos = 0
self.tokens = None
return self.peek_batch()
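
# --- Illustrative usage sketch (not part of the original gist) ---
# A minimal sketch of how this loader is meant to be driven: encode() pre-tokenizes
# the next column, next_batch() returns an (x, y) pair of shape (B, T) where y is x
# shifted by one token. The 'edu_fineweb10B' directory name and the sizes below are
# assumptions for illustration only.
def _demo_fineweb_loader():
    loader = FineWebDataLoader(
        B=4, T=256,
        file_path='edu_fineweb10B',              # assumed directory of parquet shards
        backend=SimpleDataLoaderBackend.Mlx,
        shuffle=True,
    )
    loader.encode()                              # pre-tokenize the first column
    for _ in range(3):
        x, y = loader.next_batch()
        loader.encode()                          # overlap tokenization with training work
        print(x.shape, y.shape, f'{loader.percent():.3f}% of columns consumed')
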
# Based on the Karpathy video
# Let's reproduce GPT-2 (124M)
# https://www.youtube.com/watch?v=l8pRSuU81PU
import dataclasses
from dataclasses import dataclass
import copy
import math
from typing import Optional, Tuple
import mlx
from mlx import nn
from mlx.core import array as Tensor
import mlx.core
@dataclass
class IGptV4Config:
use_kv_cache: bool = False
vocab_size: int = 8192
n_layer: int = 12
n_head: int = 12 # For multi-head attention
n_embd: int = 768 # dmodel
n_kv_heads: int = 4
dtype: mlx.core.Dtype = mlx.core.bfloat16
mlp_ratio: int = 4
def to_dict(self):
copied_config = copy.copy(self)
copied_config.dtype = str(copied_config.dtype)
return dataclasses.asdict(copied_config)
def _init_weights(name, module):
if isinstance(module, nn.Linear):
module.weight[:] = mlx.core.random.normal(
shape=module.weight.shape,
loc=0.0,
scale=0.02
)
if module.get('bias') is not None:
module.bias[:] = mlx.core.zeros_like(module.bias)
elif isinstance(module, nn.Embedding):
module.weight[:] = mlx.core.random.normal(
shape=module.weight.shape,
loc=0.0,
scale=0.02
)
# We don't need to initialize the RMSNorm layers because their weights default to all ones.
class Mlp(nn.Module):
def __init__(self, config):
super().__init__()
#self.gate = nn.Linear(config.n_embd, config.n_embd * config.mlp_ratio, bias=False)
self.c_fc = nn.Linear(config.n_embd, config.n_embd * config.mlp_ratio, bias=False)
#self.activation = nn.SiLU()
self.activation = nn.GELU(approx='fast')
self.c_proj = nn.Linear(config.n_embd * config.mlp_ratio, config.n_embd, bias=False)
self.c_fc.set_dtype(config.dtype)
self.activation.set_dtype(config.dtype)
self.c_proj.set_dtype(config.dtype)
def __call__(self, x: Tensor) -> Tensor:
#x1 = self.gate(x)
x = self.c_fc(x)
x = self.activation(x) #* x1
x = self.c_proj(x)
return x
class Block(nn.Module):
def __init__(self, config, rope):
super().__init__()
self.attn = CausalSelfAttention(config, rope)
self.ln_1 = nn.RMSNorm(config.n_embd)
self.mlp = Mlp(config)
self.ln_2 = nn.RMSNorm(config.n_embd)
def __call__(self, x: Tensor) -> Tensor:
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class CausalSelfAttention(nn.Module):
def __init__(self, config, rope):
super().__init__()
self.n_embd = config.n_embd
self.n_head = config.n_head
self.n_kv_heads = config.n_kv_heads
self.n_rep = config.n_head // self.n_kv_heads
self.head_dim = config.n_embd // config.n_head
out_dim = self.n_embd + 2 * self.head_dim * self.n_kv_heads
self.c_attn = nn.Linear(self.n_embd, out_dim, bias=False)
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.c_proj.set_dtype(config.dtype)
self.c_attn.set_dtype(config.dtype)
self.k_cache = None
self.v_cache = None
self.use_kv_cache = config.use_kv_cache
self.rope = rope
self.scale = 1.0 / math.sqrt(self.head_dim)
def __call__(self, x: Tensor) -> Tensor:
B, T, C = x.shape
qkv = self.c_attn(x)
head_dim = self.n_embd // self.n_head
q_size = self.n_head * head_dim
kv_size = self.n_kv_heads * head_dim
q = qkv[:, :, :q_size]
k = qkv[:, :, q_size:(q_size+kv_size)]
v = qkv[:, :, (q_size+kv_size):]
# got = mlx.core.split(qkv, [q_size, kv_size, kv_size], axis=-1)
# print('got', [(x.shape) for x in got], len(got))
# import pdb;pdb.set_trace()
# got = mlx.core.split(
# qkv,
# #[self.n_embd, 2 * self.head_dim * self.n_kv_heads],
# self.n_embd,
# axis=2
# )
# q, k, v = got
# k, v = mlx.core.split(
# kv,
# 2 * self.head_dim * self.n_kv_heads,
# axis=2
# )
q = q.reshape(B, T, self.n_head, C // self.n_head)
q = mlx.core.swapaxes(q, 1, 2)
k = k.reshape(B, T, self.n_kv_heads, self.head_dim)
k = mlx.core.swapaxes(k, 1, 2)
v = v.reshape(B, T, self.n_kv_heads, self.head_dim)
v = mlx.core.swapaxes(v, 1, 2)
q = self.rope(q)
k = self.rope(k)
if self.use_kv_cache:
if self.k_cache is None:
self.k_cache = k
self.v_cache = v
else:
self.k_cache = mlx.core.concatenate([self.k_cache, k], axis=2)
self.v_cache = mlx.core.concatenate([self.v_cache, v], axis=2)
k = self.k_cache
v = self.v_cache
y = mlx.core.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal")
y = y.transpose(0, 2, 1, 3).reshape([B, T, C])
if self.use_kv_cache:
y = y[:, -1, :].reshape([B, 1, C])
y = self.c_proj(y)
return y
def reset_key_value_cache(self, enable_key_value_cache: bool):
self.k_cache = None
self.v_cache = None
self.use_kv_cache = enable_key_value_cache
class Transformer(nn.Module):
def __init__(self, config):
super().__init__()
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
rope = nn.RoPE(config.n_embd//config.n_head)
self.h = nn.Sequential(*[Block(config, rope) for _ in range(config.n_layer)])
self.ln_f = nn.RMSNorm(config.n_embd)
self.wte.set_dtype(config.dtype)
self.ln_f.set_dtype(config.dtype)
self.h.set_dtype(config.dtype)
#for h in self.h:
# h.set_dtype(config.dtype)
class IGptV4(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.transformer = Transformer(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.lm_head.set_dtype(config.dtype)
        # To get the nice loss curve Karpathy gets, we need both of the following tricks:
        # sharing the token embedding weights with the output head, and initializing the
        # weights the way he does.
# init params
self.apply_to_modules(_init_weights)
self.transformer.wte.weight = self.lm_head.weight
def __call__(self, ids: Tensor, targets: Optional[Tensor]=None) -> Tuple[Tensor, Optional[Tensor]]:
x: Tensor = self.transformer.wte(ids)
x = self.transformer.h(x)
#for (i, layer) in enumerate(self.transformer.h):
# x = layer(x)
x = self.transformer.ln_f(x)
logits = self.lm_head(x)
loss = None
return logits, loss
def reset_key_value_cache(self, enable_key_value_cache: bool):
for layer in self.transformer.h.layers:
layer.attn.reset_key_value_cache(enable_key_value_cache)
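
# --- Illustrative smoke test (not part of the original gist) ---
# A minimal forward-pass shape check on a deliberately tiny config; every number
# below is an assumption chosen only so the check runs quickly.
def _demo_igptv4_forward():
    tiny = IGptV4Config(vocab_size=512, n_layer=2, n_head=4, n_embd=64,
                        n_kv_heads=2, dtype=mlx.core.float32, mlp_ratio=4)
    model = IGptV4(tiny)
    ids = mlx.core.random.randint(0, tiny.vocab_size, shape=(2, 16))
    logits, _ = model(ids)
    print('logits shape:', logits.shape)         # expect (2, 16, 512)
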
'''
I wanted to try the Pleias synthetic dataset, so I created two data loaders: one for the synthetic seed text and one for the reasoning traces.
Non-English columns are filtered out. The think phases use a custom tokenization, which I later learned differs from what Pleias uses,
but for now I'm focusing on the pre-training phase.
'''
import os
import time
import tiktoken
from tqdm import tqdm
import typing
import mlx.core
from datasets import load_dataset
class QaReasoningDataLoader:
def __init__(self, B: int, T: int, file_path: str, subset='train', start_column=None, tokenizer_name='gpt2_custom'):
self.B = B
self.T = T
if isinstance(file_path, list):
parquets = file_path
else:
parquets = [os.path.join(file_path, fname) for fname in os.listdir(file_path) if fname.endswith('parquet')]
self.dataset: typing.List = load_dataset('parquet', data_files=parquets)[subset]
self.tokens = None
self.tokens_pos = 0
if tokenizer_name == 'gpt2_custom':
gpt2_encoder = tiktoken.get_encoding('gpt2')
n_vocab = gpt2_encoder.n_vocab
self.extra_tokens = {
"<user>": n_vocab + 1,
"</user>": n_vocab + 2,
"<agent>": n_vocab + 3,
"</agent>": n_vocab + 4,
"<think>": n_vocab + 5,
"</think>": n_vocab + 6,
}
self.encoder = tiktoken.Encoding(
name="gpt2_custom",
pat_str=gpt2_encoder._pat_str,
mergeable_ranks=gpt2_encoder._mergeable_ranks,
special_tokens={**gpt2_encoder._special_tokens, **self.extra_tokens},
)
elif tokenizer_name == 'monad':
from tokenizers import Tokenizer
self.encoder = Tokenizer.from_file("monad_tokenizer.json")
else:
raise NotImplementedError(tokenizer_name)
self.eot = self.encoder._special_tokens['<|endoftext|>']
self.epoch = 0
self.current_column = 0
self.start_column = start_column
if start_column is not None:
self.current_column = start_column
self.current_position = 0
self.dtype = mlx.core.uint16
self.batch = mlx.core.zeros((B * T + 1), dtype=self.dtype)
#print('Batch size', B * T)
#total_tokens = 10e9
#print(f"1 epoch = {self.batches_per_epoch()} batches")
self.progress_bar = None
#self.progress_bar = tqdm(total=len(self.dataset), unit='cols', desc='cols ')
self.encoded = None
self.next_encoded = None
#def __del__(self):
# self.progress_bar.close()
def batches_per_epoch(self):
total_tokens = 10e9
return total_tokens / (self.B * self.T)
def percent(self):
return self.current_column / len(self.dataset) * 100
def column_text(self, column):
query = column['query']
synthetic_reasoning = column['synthetic_reasoning']
synthetic_answer = column['synthetic_answer']
return f'''<user>{query}</user><think>{synthetic_reasoning}</think><agent>{synthetic_answer}</agent><|endoftext|>'''
def encode(self):
if self.next_encoded is None:
t0 = time.time()
#tokens_list = [self.eot] + self.encoder.encode(self.dataset[self.current_column], allowed_special={'<|endoftext|>',})
tokens_list = self.encoder.encode(self.column_text(self.dataset[self.current_column]), allowed_special='all')
self.next_encoded = mlx.core.array(tokens_list, dtype=self.dtype)
t1 = time.time()
#self.durations.append(t1 - t0)
    def get_encoded(self):
        # Consume the column pre-tokenized by encode(). A single batch can span several
        # short columns, so fall back to encoding synchronously when the buffer is empty;
        # otherwise the previously returned tokens would be reused for every column.
        if self.next_encoded is None:
            self.encode()
        assert self.next_encoded is not None
        self.encoded = self.next_encoded
        self.next_encoded = None
        return self.encoded
    def reset(self):
        self.current_position = 0
        # Return to the configured starting column (defaults to 0).
        if self.start_column is None:
            self.start_column = 0
        self.current_column = self.start_column
def view(self, tensor, B, T):
return tensor.reshape((B, T))
def next_column(self):
#print('column', self.current_column)
if self.current_column >= len(self.dataset):
self.current_column = 0
self.epoch += 1
if self.progress_bar:
self.progress_bar.reset()
#t0 = time.time()
#if self.backend == SimpleDataLoaderBackend.Torch:
# self.tokens = torch.tensor([self.eot] + self.encoder.encode(self.dataset[self.current_column], allowed_special={'<|endoftext|>',}), dtype=self.dtype)
#else:
# self.tokens = mlx.core.array([self.eot] + self.encoder.encode(self.dataset[self.current_column], allowed_special={'<|endoftext|>',}), dtype=self.dtype)
#t1 = time.time()
#self.durations.append(t1 - t0)
self.tokens = self.get_encoded()
self.tokens_pos = 0
while True:
self.current_column += 1
if self.progress_bar:
self.progress_bar.update(1)
if self.dataset[self.current_column]['language'] == 'en':
break
return self.tokens
def peek_batch(self):
B, T = self.B, self.T
x = self.view(self.batch[:-1], B, T)
y = self.view(self.batch[1:], B, T)
return x, y
def next_batch(self):
batch_pos = 0
while batch_pos < len(self.batch):
if self.tokens_pos == 0:
self.batch[batch_pos] = self.eot
batch_pos += 1
if batch_pos == len(self.batch):
break
self.next_column()
batch_remaining = len(self.batch) - batch_pos
assert self.tokens is not None
tokens_remaining = len(self.tokens) - self.tokens_pos
chunk_len = min(batch_remaining, tokens_remaining)
self.batch[batch_pos:batch_pos+chunk_len] = self.tokens[self.tokens_pos:self.tokens_pos+chunk_len]
batch_pos += chunk_len
self.tokens_pos += chunk_len
if self.tokens_pos >= len(self.tokens):
self.tokens_pos = 0
self.tokens = None
return self.peek_batch()
class QaReasoningSeedTextDataLoader(QaReasoningDataLoader):
def column_text(self, column):
seed_text = column['query_seed_text']
return f'''{seed_text}<|endoftext|>'''
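
# --- Illustrative usage sketch (not part of the original gist) ---
# Shows how the custom gpt2 encoding with <user>/<think>/<agent> markers behaves.
# The example string is made up purely for illustration; it round-trips through
# the added special tokens.
def _demo_custom_tokenizer():
    gpt2_encoder = tiktoken.get_encoding('gpt2')
    extra = {
        "<user>": gpt2_encoder.n_vocab + 1,
        "</user>": gpt2_encoder.n_vocab + 2,
        "<agent>": gpt2_encoder.n_vocab + 3,
        "</agent>": gpt2_encoder.n_vocab + 4,
        "<think>": gpt2_encoder.n_vocab + 5,
        "</think>": gpt2_encoder.n_vocab + 6,
    }
    enc = tiktoken.Encoding(
        name="gpt2_custom",
        pat_str=gpt2_encoder._pat_str,
        mergeable_ranks=gpt2_encoder._mergeable_ranks,
        special_tokens={**gpt2_encoder._special_tokens, **extra},
    )
    text = "<user>What is 2+2?</user><think>2+2=4</think><agent>4</agent><|endoftext|>"
    ids = enc.encode(text, allowed_special='all')
    print(len(ids), enc.decode(ids) == text)
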
#!/usr/bin/env python3
# Training script for IGptV4
# This training script was mostly generated using the following model since I lost faith in my own
# training script and am flailing a bit right now. However, I'm still having the same issues.
# Model: Claude Sonnet 4.5, Date: 2025-12-12
import random
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_map, tree_flatten
from igptv4 import IGptV4, IGptV4Config
from q_and_a_reasoning_data_loader import QaReasoningDataLoader, QaReasoningSeedTextDataLoader
from fineweb_data_loader import FineWebDataLoader
from simple_data_loader import SimpleDataLoaderBackend, SimpleDataLoader
import time
import tiktoken
import math
import wandb
from datetime import datetime
import os
def count_params(model):
"""Count model parameters, avoiding double counting tied weights"""
seen_weights = set()
total = 0
def count_array(arr):
nonlocal total, seen_weights
arr_id = id(arr)
if arr_id not in seen_weights:
seen_weights.add(arr_id)
total += arr.size
def traverse(node):
if isinstance(node, mx.array):
count_array(node)
elif isinstance(node, dict):
for v in node.values():
traverse(v)
elif isinstance(node, (list, tuple)):
for item in node:
traverse(item)
traverse(model.parameters())
return total
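
# --- Illustrative check (not part of the original gist) ---
# count_params walks the parameter tree and counts each underlying array once,
# so a weight shared between two modules (as with the tied embedding / lm_head)
# is not double counted. A toy stand-in, purely for illustration:
def _demo_count_params_tied():
    class _Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.a = nn.Linear(4, 4, bias=False)
            self.b = nn.Linear(4, 4, bias=False)
            self.b.weight = self.a.weight        # tie the two weights
    print(count_params(_Tiny()))                 # should print 16, not 32
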
def format_time(seconds):
"""Format seconds into human readable time"""
if seconds < 60:
return f"{seconds:.0f}s"
elif seconds < 3600:
return f"{seconds//60:.0f}m {seconds%60:.0f}s"
else:
return f"{seconds//3600:.0f}h {(seconds%3600)//60:.0f}m"
def get_lr(step, warmup_steps, max_steps, max_lr, min_lr):
"""
Learning rate schedule with warmup and cosine decay.
Must be compilable by MLX (no conditionals based on step).
"""
# Linear warmup
warmup_lr = max_lr * step / warmup_steps
# Cosine decay after warmup
decay_steps = max_steps - warmup_steps
decay_progress = (step - warmup_steps) / decay_steps
cosine_decay = 0.5 * (1.0 + mx.cos(math.pi * decay_progress))
decay_lr = min_lr + (max_lr - min_lr) * cosine_decay
# Use mx.where for conditional selection (compilable)
lr = mx.where(step < warmup_steps, warmup_lr, decay_lr)
return lr
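
# --- Illustrative sanity check (not part of the original gist) ---
# With warmup_steps=100, max_steps=1000, max_lr=1e-4, min_lr=1e-5 (numbers assumed
# for illustration), the schedule rises linearly from 0 to max_lr over the warmup,
# then follows a cosine down to min_lr at max_steps.
def _demo_lr_schedule():
    for s in (0, 50, 100, 550, 1000):
        lr = get_lr(mx.array(s), warmup_steps=100, max_steps=1000,
                    max_lr=1e-4, min_lr=1e-5)
        print(s, float(lr.item()))               # 0.0, 5e-5, 1e-4, ~5.5e-5, 1e-5
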
def loss_fn(model, inputs, targets):
"""Compute cross-entropy loss"""
logits, _ = model(inputs)
B, T, V = logits.shape
logits = logits.reshape(B * T, V)
targets = targets.reshape(B * T)
loss = nn.losses.cross_entropy(logits, targets, reduction='mean')
return loss
def train_step(model, optimizer, inputs, targets, accumulated_grads=None):
"""Single training step with gradient computation"""
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, inputs, targets)
# Force evaluation to free computation graph
mx.eval(loss)
# Accumulate gradients with tree structure handling
if accumulated_grads is None:
accumulated_grads = tree_map(lambda g: g.astype(mx.float32), grads)
else:
accumulated_grads = tree_map(lambda a, g: a + g, accumulated_grads, grads)
# Force evaluation of accumulated gradients to free intermediate graphs
mx.eval(accumulated_grads)
return loss, accumulated_grads
def validate(model, val_loader, num_batches=50):
"""Run validation and return average loss and perplexity"""
total_loss = 0.0
for _ in range(num_batches):
val_loader.encode()
inputs, targets = val_loader.next_batch()
loss = loss_fn(model, inputs, targets)
mx.eval(loss)
total_loss += loss.item()
avg_loss = total_loss / num_batches
perplexity = math.exp(avg_loss)
return avg_loss, perplexity
def compute_gradient_health_metrics(grads):
"""
Compute comprehensive gradient health metrics.
Returns dict with various gradient statistics.
"""
flat_grads = list(tree_flatten(grads, destination={}).values())
# Global statistics
grad_values = mx.concatenate([g.reshape(-1) for g in flat_grads])
metrics = {
'grad_norm': float(mx.sqrt(mx.sum(grad_values * grad_values)).item()),
'grad_mean': float(mx.mean(grad_values).item()),
'grad_std': float(mx.std(grad_values).item()),
'grad_max': float(mx.max(mx.abs(grad_values)).item()),
'grad_min': float(mx.min(mx.abs(grad_values)).item()),
}
# Per-layer statistics (first, middle, last layers)
layer_grads = {}
for key, grad in tree_flatten(grads, destination={}).items():
# Extract layer number if present
if 'layers.' in key:
layer_num = key.split('layers.')[1].split('.')[0]
if layer_num not in layer_grads:
layer_grads[layer_num] = []
layer_grads[layer_num].append(grad)
if layer_grads:
layer_nums = sorted([int(k) for k in layer_grads.keys()])
if layer_nums:
first_layer = str(layer_nums[0])
last_layer = str(layer_nums[-1])
# First layer norm
first_layer_values = mx.concatenate([g.reshape(-1) for g in layer_grads[first_layer]])
metrics['grad_norm_layer_first'] = float(mx.sqrt(mx.sum(first_layer_values * first_layer_values)).item())
# Last layer norm
last_layer_values = mx.concatenate([g.reshape(-1) for g in layer_grads[last_layer]])
metrics['grad_norm_layer_last'] = float(mx.sqrt(mx.sum(last_layer_values * last_layer_values)).item())
# Ratio (useful for detecting vanishing/exploding)
if metrics['grad_norm_layer_first'] > 0:
metrics['grad_ratio_last_to_first'] = metrics['grad_norm_layer_last'] / metrics['grad_norm_layer_first']
else:
metrics['grad_ratio_last_to_first'] = 0.0
return metrics
def chinchilla_optimal_tokens(num_params):
"""Calculate Chinchilla-optimal token budget (20 tokens per parameter)"""
return num_params * 20
gpt2_tokenizer = tiktoken.get_encoding('gpt2')
extra_tokens = {
"<user>": gpt2_tokenizer.n_vocab + 1,
"</user>": gpt2_tokenizer.n_vocab + 2,
"<agent>": gpt2_tokenizer.n_vocab + 3,
"</agent>": gpt2_tokenizer.n_vocab + 4,
"<think>": gpt2_tokenizer.n_vocab + 5,
"</think>": gpt2_tokenizer.n_vocab + 6,
}
tokenizer = tiktoken.Encoding(
name="gpt2_custom",
pat_str=gpt2_tokenizer._pat_str,
mergeable_ranks=gpt2_tokenizer._mergeable_ranks,
special_tokens={**gpt2_tokenizer._special_tokens, **extra_tokens},
)
def tokenizer_decode(tokens):
decoded = tokenizer.decode(tokens)
return decoded
def tokenizer_encode(text):
decoded = tokenizer.encode(text)
return decoded
@mx.compile
def sample_topk(logits: mx.array, temperature: float = 0.5, k: int = 5) -> mx.array:
    logits = logits / temperature
    probs = mx.softmax(logits, axis=-1)
    # Partition so the k largest probabilities land in the last k slots, then pick
    # one of those candidate token ids uniformly at random.
    topk_indices = mx.argpartition(probs, probs.shape[-1] - k, axis=-1)[..., -k:]
    choice = mx.random.randint(0, k)
    return topk_indices[choice]
def infer(model, max_tokens=50, prompt='Once upon a time,'):
tokens = tokenizer_encode(prompt)
model.train(False)
model.reset_key_value_cache(True)
for _ in (range(max_tokens)):
token_id_array = mx.array(tokens, dtype=mx.uint16)
token_id_array = token_id_array.reshape((1, *token_id_array.shape))
result, loss = model(token_id_array, None)
result = result[-1,-1]
chosen = sample_topk(result)
del token_id_array
tokens.append(chosen.item())
try:
decoded = tokenizer_decode(tokens)
except KeyError as e:
print(e)
decoded = ''
model.reset_key_value_cache(False)
model.train(True)
return decoded
def main():
# Base hyperparameters (adjustable)
micro_batch_size = 8 # Batch size per gradient step
sequence_length = 1024 # sequence length
    # We want the dataset shuffle to be random but reproducible between runs.
rng_seed = 42
random.seed(rng_seed)
# Training configuration
    max_learning_rate = 1e-4  # Reduced from 6e-4 while chasing vanishing gradients
min_learning_rate = 0.1 * max_learning_rate
weight_decay = 0.1
grad_clip_norm = 1.0 # Gradient clipping threshold (1.0 is standard for transformers)
log_interval = 100
val_interval = 500
save_interval = 1000
infer_interval = 500
grad_health_interval = 200 # Check gradient health metrics
overwrite_checkpoints = True # If True, only keep the last checkpoint to save disk space
memory_limit = 20 * 1024 ** 3 # 20GB memory limit
should_save = True # Set to False to disable all saving (checkpoints, wandb) for testing
use_wandb = True # Set to False to disable wandb logging (overridden by should_save)
# Warmup will be calculated as a fraction of max_steps later
warmup_fraction = 0.05 # 5% of training for warmup
# Generate timestamp for this training run
run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Create output directory for this run (only if saving)
if should_save:
output_dir = f"runs/run_{run_timestamp}"
os.makedirs(output_dir, exist_ok=True)
print(f"Output directory: {output_dir}")
else:
output_dir = None
print("Saving disabled - running in test mode")
# Model configuration
config = IGptV4Config(
vocab_size=50304,
n_layer=12,
n_head=12,
n_embd=768,
n_kv_heads=4,
dtype=mx.bfloat16,
mlp_ratio=4,
use_kv_cache=False
)
print("Initializing model...")
model = IGptV4(config)
mx.eval(model.parameters())
# Count parameters (accounting for tied weights)
num_params = count_params(model)
# Calculate Chinchilla-optimal token budget
chinchilla_tokens = chinchilla_optimal_tokens(num_params)
# Decide on a reasonable batch size (e.g., 0.5M tokens per batch)
target_tokens_per_batch = 64 * 1024
# Calculate gradient accumulation steps
tokens_per_micro_batch = micro_batch_size * sequence_length
grad_accum_steps = target_tokens_per_batch // tokens_per_micro_batch
# Calculate effective batch size
effective_batch_size = micro_batch_size * grad_accum_steps
effective_tokens_per_batch = effective_batch_size * sequence_length
# Calculate number of steps needed to reach Chinchilla-optimal tokens
# Each step processes one micro batch
max_steps = chinchilla_tokens // tokens_per_micro_batch
# Calculate warmup steps as a fraction of total steps
warmup_steps = int(warmup_fraction * max_steps)
# Calculate number of optimizer updates
num_optimizer_updates = max_steps // grad_accum_steps
# Total tokens that will be seen
total_tokens_seen = max_steps * tokens_per_micro_batch
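    # Worked example of the arithmetic above (added for clarity; the numbers assume
    # the defaults micro_batch_size=8, sequence_length=1024, target 64*1024 tokens/batch):
    #   tokens_per_micro_batch = 8 * 1024        = 8,192
    #   grad_accum_steps       = 65,536 // 8,192 = 8
    #   effective_batch_size   = 8 * 8           = 64   (65,536 tokens per optimizer update)
    #   max_steps              = (num_params * 20) // 8,192 micro-batch steps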
shuffle_dataset = True
print(f"Model parameters: {num_params:,}")
print(f"Chinchilla optimal tokens: {chinchilla_tokens:,} ({chinchilla_tokens/1e9:.2f}B)")
print(f"Target tokens per batch: {target_tokens_per_batch:,}")
print(f"Tokens per micro batch: {tokens_per_micro_batch:,}")
print(f"Gradient accumulation steps: {grad_accum_steps}")
print(f"Effective batch size: {effective_batch_size} ({effective_tokens_per_batch:,} tokens)")
print(f"Calculated loop steps: {max_steps:,}")
print(f"Optimizer updates: {num_optimizer_updates:,}")
print(f"Total tokens in training: {total_tokens_seen:,} ({total_tokens_seen/1e9:.2f}B)")
print(f"Chinchilla ratio: {total_tokens_seen/chinchilla_tokens:.2f}x")
# Initialize wandb (only if saving is enabled)
if should_save and use_wandb:
wandb.init(
project="igptv4-training",
save_code=True,
config={
"vocab_size": config.vocab_size,
"n_layer": config.n_layer,
"n_head": config.n_head,
"n_embd": config.n_embd,
"n_kv_heads": config.n_kv_heads,
"mlp_ratio": config.mlp_ratio,
"micro_batch_size": micro_batch_size,
"sequence_length": sequence_length,
"grad_accum_steps": grad_accum_steps,
"effective_batch_size": effective_batch_size,
"effective_tokens_per_batch": effective_tokens_per_batch,
"target_tokens_per_batch": target_tokens_per_batch,
"max_steps": max_steps,
"max_learning_rate": max_learning_rate,
"min_learning_rate": min_learning_rate,
"weight_decay": weight_decay,
"warmup_steps": warmup_steps,
"num_params": num_params,
"chinchilla_optimal_tokens": chinchilla_tokens,
"total_tokens": total_tokens_seen,
"shuffle_dataset": shuffle_dataset,
}
)
# Initialize optimizer
optimizer = optim.AdamW(learning_rate=max_learning_rate, weight_decay=weight_decay)
# Initialize data loaders
print("Loading dataset...")
# train_loader = QaReasoningSeedTextDataLoader(micro_batch_size, sequence_length, 'synth', start_column=0, shuffle=shuffle_dataset)
# validation_loader = QaReasoningSeedTextDataLoader(micro_batch_size, sequence_length, 'synth/validate', start_column=0)
train_loader = FineWebDataLoader(micro_batch_size, sequence_length, 'edu_fineweb10B', SimpleDataLoaderBackend.Mlx, start_column=0, shuffle=shuffle_dataset)
validation_loader = FineWebDataLoader(micro_batch_size, sequence_length, 'edu_fineweb10B/validate', SimpleDataLoaderBackend.Mlx, start_column=0)
# Pre-encode first batch
train_loader.encode()
print(f"\nStarting training for {max_steps} steps...")
print(f"Micro batch size: {micro_batch_size}, Sequence length: {sequence_length}")
print(f"Gradient accumulation steps: {grad_accum_steps}")
print(f"Effective batch size: {effective_batch_size} ({effective_tokens_per_batch:,} tokens)")
print(f"Learning rate: {max_learning_rate} -> {min_learning_rate} (warmup: {warmup_steps} steps)")
print(f"Weight decay: {weight_decay}")
print(f"Log interval: {log_interval}, Val interval: {val_interval}, Save interval: {save_interval}, Inference interval {infer_interval}\n")
print("Step, Loss, Perplexity, LR, Tokens/sec, Time, ETA, GradNorm, ValLoss, ValPpl")
start_time = time.time()
global_start_time = start_time
running_loss = 0.0
accumulated_grads = None
steps_since_log = 0
total_tokens_processed = 0
last_checkpoint_path = None # Track last checkpoint for deletion
running_grad_norm = 0.0 # Track gradient norms for logging
last_grad_health = None # Store last gradient health metrics
for step in range(max_steps):
step_start = time.time()
# Memory limit check
if mx.get_active_memory() > memory_limit:
print(f"\nMemory limit exceeded at step {step}!")
print(f"Active memory: {mx.get_active_memory() / 1024**3:.2f}GB")
print(f"Memory limit: {memory_limit / 1024**3:.2f}GB")
break
# Update learning rate with warmup and cosine decay
current_lr = get_lr(mx.array(step), warmup_steps, max_steps, max_learning_rate, min_learning_rate)
optimizer.learning_rate = current_lr.item()
# Get batch
inputs, targets = train_loader.next_batch()
# Pre-encode next batch while training on current
train_loader.encode()
# Training step with gradient accumulation
loss, accumulated_grads = train_step(model, optimizer, inputs, targets, accumulated_grads)
# Force evaluation of accumulated grads every step to prevent memory buildup
mx.eval(accumulated_grads)
running_loss += loss.item()
steps_since_log += 1
total_tokens_processed += tokens_per_micro_batch
# Update weights after accumulating gradients
if (step + 1) % grad_accum_steps == 0:
# Average the accumulated gradients
scale = 1.0 / grad_accum_steps
averaged_grads = tree_map(lambda g: g * scale, accumulated_grads)
mx.eval(averaged_grads)
# Compute gradient health metrics periodically
# if (((step + 1) / grad_accum_steps)+1) % grad_health_interval == 0:
# last_grad_health = compute_gradient_health_metrics(averaged_grads)
#
# print(f"\n=== Gradient Health Check at Step {step + 1} ===")
# print(f" Grad Norm: {last_grad_health['grad_norm']:.6f}")
# print(f" Grad Mean: {last_grad_health['grad_mean']:.6e}")
# print(f" Grad Std: {last_grad_health['grad_std']:.6e}")
# print(f" Grad Max: {last_grad_health['grad_max']:.6e}")
# print(f" Grad Min: {last_grad_health['grad_min']:.6e}")
#
# if 'grad_norm_layer_first' in last_grad_health:
# print(f" First Layer Norm: {last_grad_health['grad_norm_layer_first']:.6f}")
# print(f" Last Layer Norm: {last_grad_health['grad_norm_layer_last']:.6f}")
# print(f" Ratio (Last/First): {last_grad_health['grad_ratio_last_to_first']:.4f}")
#
# # Diagnostic messages
# if last_grad_health['grad_ratio_last_to_first'] < 0.01:
# print(" ⚠️ WARNING: Severe vanishing gradients detected!")
# elif last_grad_health['grad_ratio_last_to_first'] > 100:
# print(" ⚠️ WARNING: Exploding gradients detected!")
# else:
# print(" ✓ Gradient flow looks healthy")
#
# print("=" * 50 + "\n")
#
# # Log to wandb
# if should_save and use_wandb:
# wandb.log({
# "grad_health/norm": last_grad_health['grad_norm'],
# "grad_health/mean": last_grad_health['grad_mean'],
# "grad_health/std": last_grad_health['grad_std'],
# "grad_health/max": last_grad_health['grad_max'],
# "grad_health/min": last_grad_health['grad_min'],
# "grad_health/norm_layer_first": last_grad_health.get('grad_norm_layer_first', 0),
# "grad_health/norm_layer_last": last_grad_health.get('grad_norm_layer_last', 0),
# "grad_health/ratio_last_to_first": last_grad_health.get('grad_ratio_last_to_first', 0),
# "step": step + 1,
# "tokens": total_tokens_processed,
# })
# Apply gradient clipping by global norm
flat_grads = tree_flatten(averaged_grads, destination={})
grad_norm = mx.sqrt(sum(mx.sum(g * g) for g in flat_grads.values()))
mx.eval(grad_norm)
grad_norm_value = float(grad_norm.item())
running_grad_norm += grad_norm_value
if grad_norm_value > grad_clip_norm:
scale = grad_clip_norm / grad_norm_value
averaged_grads = tree_map(lambda g: g * scale, averaged_grads)
mx.eval(averaged_grads)
optimizer.update(model, averaged_grads)
mx.eval(model.parameters())
# Clear accumulated gradients
accumulated_grads = None
# Validation
val_loss_str = ""
val_ppl_str = ""
if (step + 1) % val_interval == 0:
val_loss, val_ppl = validate(model, validation_loader)
val_loss_str = f"{val_loss:.4f}"
val_ppl_str = f"{val_ppl:.2f}"
# Convert to Python floats before logging to prevent wandb from holding references
if should_save and use_wandb:
wandb.log({
"val/loss": float(val_loss),
"val/perplexity": float(val_ppl),
"step": step + 1,
"tokens": total_tokens_processed,
})
if (step + 1) % infer_interval == 0:
text = infer(model)
print('Inference sample: ', text)
# Logging
if (step + 1) % log_interval == 0:
avg_loss = running_loss / steps_since_log
elapsed = time.time() - start_time
tokens_per_sec = (tokens_per_micro_batch * steps_since_log) / elapsed
# Calculate ETA based on actual time per step (including all overheads)
total_elapsed = time.time() - global_start_time
avg_time_per_step = total_elapsed / (step + 1)
remaining_steps = max_steps - (step + 1)
eta_seconds = avg_time_per_step * remaining_steps
# Compute perplexity
ppl = math.exp(avg_loss)
current_lr = float(optimizer.learning_rate)
# Average gradient norm over optimizer updates since last log
num_updates_since_log = steps_since_log // grad_accum_steps
avg_grad_norm = running_grad_norm / max(num_updates_since_log, 1)
print(f"{step + 1}/{max_steps}, "
f"{avg_loss:.4f}, "
f"{ppl:.2f}, "
f"{current_lr:.2e}, "
f"{tokens_per_sec:.0f}, "
f"{elapsed:.1f}s, "
f"{format_time(eta_seconds)}, "
f"{avg_grad_norm:.3f}, "
f"{val_loss_str}, "
f"{val_ppl_str}")
# Convert all values to Python types before logging to wandb
if should_save and use_wandb:
wandb.log({
"train/loss": float(avg_loss),
"train/perplexity": float(ppl),
"train/learning_rate": float(current_lr),
"train/tokens_per_sec": float(tokens_per_sec),
"train/grad_norm": float(avg_grad_norm),
"step": int(step + 1),
"tokens": int(total_tokens_processed),
})
running_loss = 0.0
running_grad_norm = 0.0
steps_since_log = 0
start_time = time.time()
# Save checkpoint
if should_save and (step + 1) % save_interval == 0:
# Delete previous checkpoint if overwrite mode is enabled
if overwrite_checkpoints and last_checkpoint_path is not None:
if os.path.exists(last_checkpoint_path):
os.remove(last_checkpoint_path)
print(f"Deleted previous checkpoint: {last_checkpoint_path}")
opt_path = last_checkpoint_path.replace('.safetensors', '_optimizer.safetensors')
if os.path.exists(opt_path):
os.remove(opt_path)
# Save new checkpoint
checkpoint_path = os.path.join(output_dir, f"checkpoint_step_{step + 1}.safetensors")
optimizer_path = os.path.join(output_dir, f"checkpoint_step_{step + 1}_optimizer.safetensors")
print(f"Saving checkpoint to {checkpoint_path}...")
model.save_weights(checkpoint_path)
# Save optimizer state using tree_flatten
print(f"Saving optimizer state to {optimizer_path}...")
optimizer_state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors(optimizer_path, optimizer_state)
last_checkpoint_path = checkpoint_path
print("\nTraining complete!")
# Save final model (only if saving is enabled)
if should_save:
final_model_path = os.path.join(output_dir, "model_final.safetensors")
final_optimizer_path = os.path.join(output_dir, "model_final_optimizer.safetensors")
print(f"Saving final model to {final_model_path}...")
model.save_weights(final_model_path)
print(f"Saving final optimizer state to {final_optimizer_path}...")
optimizer_state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors(final_optimizer_path, optimizer_state)
print("Model and optimizer saved!")
else:
print("Saving disabled - no final model saved")
if should_save and use_wandb:
wandb.finish()
if __name__ == "__main__":
main()