igptv4 llm training 2025-12-14
fineweb_data_loader.py
| ''' | |
| We don't have enough disk to unpack tokenized copies of the fineweb dataset, so we tokenize as we go. | |
| ''' | |
| import os | |
| import time | |
| import tiktoken | |
| import random | |
| from enum import Enum | |
| from tqdm import tqdm | |
| import torch | |
| import typing | |
| import mlx.core | |
| from datasets import load_dataset | |
| from simple_data_loader import SimpleDataLoaderBackend | |
| class FineWebDataLoader: | |
| def __init__(self, B: int, T: int, file_path: str, backend: int, device='mps', subset='train', start_column=None, feature='text', shuffle=False): | |
| self.durations = [] | |
| self.shuffle = shuffle | |
| self.B = B | |
| self.T = T | |
| self.device = device | |
| self.backend = backend | |
| if isinstance(file_path, list): | |
| parquets = file_path | |
| else: | |
| parquets = [os.path.join(file_path, fname) for fname in os.listdir(file_path) if fname.endswith('parquet')] | |
| self.dataset: typing.List = load_dataset('parquet', data_files=parquets)[subset][feature] | |
| self.tokens = None | |
| self.tokens_pos = 0 | |
| if self.shuffle: | |
| self.shuffled_indices = (list(range(len(self.dataset)))) | |
| random.shuffle(self.shuffled_indices) | |
| self.encoder = tiktoken.get_encoding('gpt2') | |
| self.eot = self.encoder._special_tokens['<|endoftext|>'] | |
| self.epoch = 0 | |
| self.current_column = 0 | |
| self.start_column = start_column | |
| if start_column is not None: | |
| self.current_column = start_column | |
| self.current_position = 0 | |
| if self.backend == SimpleDataLoaderBackend.Torch: | |
| self.dtype = torch.int32 | |
| self.batch = torch.empty((B * T + 1), dtype=self.dtype) | |
| else: | |
| self.dtype = mlx.core.uint16 | |
| self.batch = mlx.core.zeros((B * T + 1), dtype=self.dtype) | |
| print('Batch size', B * T) | |
| #total_tokens = 10e9 | |
| print(f"1 epoch = {self.batches_per_epoch()} batches") | |
| #self.progress_bar = tqdm(total=len(self.dataset), unit='cols', desc='cols ') | |
| self.progress_bar = None | |
| self.encoded = None | |
| self.next_encoded = None | |
| #def __del__(self): | |
| # self.progress_bar.close() | |
| def batches_per_epoch(self): | |
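| # Rough estimate: assumes ~10B tokens total, matching the edu_fineweb10B sample used in the training script. | |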
| total_tokens = 10e9 | |
| return total_tokens / (self.B * self.T) | |
| def percent(self): | |
| return self.current_column / len(self.dataset) * 100 | |
| def get_column(self, index): | |
| if self.shuffle: | |
| index = self.shuffled_indices[index] | |
| return self.dataset[index] | |
| def encode(self): | |
| if self.next_encoded is None: | |
| t0 = time.time() | |
| tokens_list = [self.eot] + self.encoder.encode(self.get_column(self.current_column), allowed_special={'<|endoftext|>',}) | |
| if self.backend == SimpleDataLoaderBackend.Torch: | |
| self.next_encoded = torch.tensor(tokens_list, dtype=self.dtype) | |
| else: | |
| self.next_encoded = mlx.core.array(tokens_list, dtype=self.dtype) | |
| t1 = time.time() | |
| self.durations.append(t1 - t0) | |
| def get_encoded(self): | |
| # Consume the prefetched document. If a batch spans more than one document, the | |
| # prefetched one is already used up, so tokenize the next column on demand. | |
| if self.next_encoded is None: | |
| self.encode() | |
| encoded = self.next_encoded | |
| self.next_encoded = None | |
| return encoded | |
| def reset(self): | |
| self.current_position = 0 | |
| if self.start_column is not None: | |
| self.current_column = self.start_column | |
| else: | |
| self.start_column = 0 | |
| def view(self, tensor, B, T): | |
| #match self.backend: | |
| if self.backend == SimpleDataLoaderBackend.Torch: | |
| return tensor.view(B, T).to(self.device) | |
| else: # SimpleDataLoaderBackend.Mlx: | |
| return tensor.reshape((B, T)) | |
| def next_column(self): | |
| #print('column', self.current_column) | |
| if self.current_column >= len(self.dataset): | |
| self.current_column = 0 | |
| self.epoch += 1 | |
| if self.progress_bar: | |
| self.progress_bar.reset() | |
| if self.progress_bar: | |
| self.progress_bar.update(1) | |
| #t0 = time.time() | |
| #if self.backend == SimpleDataLoaderBackend.Torch: | |
| # self.tokens = torch.tensor([self.eot] + self.encoder.encode(self.dataset[self.current_column], allowed_special={'<|endoftext|>',}), dtype=self.dtype) | |
| #else: | |
| # self.tokens = mlx.core.array([self.eot] + self.encoder.encode(self.dataset[self.current_column], allowed_special={'<|endoftext|>',}), dtype=self.dtype) | |
| #t1 = time.time() | |
| #self.durations.append(t1 - t0) | |
| self.tokens = self.get_encoded() | |
| self.tokens_pos = 0 | |
| self.current_column += 1 | |
| return self.tokens | |
| def peek_batch(self): | |
| B, T = self.B, self.T | |
| x = self.view(self.batch[:-1], B, T) | |
| y = self.view(self.batch[1:], B, T) | |
| return x, y | |
| def next_batch(self): | |
| batch_pos = 0 | |
| while batch_pos < len(self.batch): | |
| if self.tokens_pos == 0: | |
| #self.batch[batch_pos] = self.eot | |
| #batch_pos += 1 | |
| if batch_pos == len(self.batch): | |
| break | |
| self.next_column() | |
| batch_remaining = len(self.batch) - batch_pos | |
| assert self.tokens is not None | |
| tokens_remaining = len(self.tokens) - self.tokens_pos | |
| chunk_len = min(batch_remaining, tokens_remaining) | |
| self.batch[batch_pos:batch_pos+chunk_len] = self.tokens[self.tokens_pos:self.tokens_pos+chunk_len] | |
| batch_pos += chunk_len | |
| self.tokens_pos += chunk_len | |
| if self.tokens_pos >= len(self.tokens): | |
| self.tokens_pos = 0 | |
| self.tokens = None | |
| return self.peek_batch() | |
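For reference, a minimal usage sketch of this loader (hypothetical, not part of the gist): it assumes parquet shards under edu_fineweb10B/, mirroring how the training script drives the loader.
| if __name__ == '__main__': | |
| # Prefetch the first document, then pull one (x, y) pair of next-token training batches. | |
| loader = FineWebDataLoader(B=8, T=1024, file_path='edu_fineweb10B', backend=SimpleDataLoaderBackend.Mlx) | |
| loader.encode() | |
| x, y = loader.next_batch() | |
| print(x.shape, y.shape)  # (8, 1024) (8, 1024) | |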
igptv4.py
| # Based on the Karpathy video | |
| # Let's reproduce GPT-2 (124M) | |
| # https://www.youtube.com/watch?v=l8pRSuU81PU | |
| import dataclasses | |
| from dataclasses import dataclass | |
| import copy | |
| import math | |
| from typing import Optional, Tuple | |
| import mlx | |
| from mlx import nn | |
| from mlx.core import array as Tensor | |
| import mlx.core | |
| @dataclass | |
| class IGptV4Config: | |
| use_kv_cache: bool = False | |
| vocab_size: int = 8192 | |
| n_layer: int = 12 | |
| n_head: int = 12 # For multi-head attention | |
| n_embd: int = 768 # dmodel | |
| n_kv_heads: int = 4 | |
| dtype: mlx.core.Dtype = mlx.core.bfloat16 | |
| mlp_ratio: int = 4 | |
| def to_dict(self): | |
| copied_config = copy.copy(self) | |
| copied_config.dtype = str(copied_config.dtype) | |
| return dataclasses.asdict(copied_config) | |
| def _init_weights(name, module): | |
| if isinstance(module, nn.Linear): | |
| module.weight[:] = mlx.core.random.normal( | |
| shape=module.weight.shape, | |
| loc=0.0, | |
| scale=0.02 | |
| ) | |
| if module.get('bias') is not None: | |
| module.bias[:] = mlx.core.zeros_like(module.bias) | |
| elif isinstance(module, nn.Embedding): | |
| module.weight[:] = mlx.core.random.normal( | |
| shape=module.weight.shape, | |
| loc=0.0, | |
| scale=0.02 | |
| ) | |
| # We don't need to initialize RMSNorm because its weights are already initialized to all 1s | |
| class Mlp(nn.Module): | |
| def __init__(self, config): | |
| super().__init__() | |
| #self.gate = nn.Linear(config.n_embd, config.n_embd * config.mlp_ratio, bias=False) | |
| self.c_fc = nn.Linear(config.n_embd, config.n_embd * config.mlp_ratio, bias=False) | |
| #self.activation = nn.SiLU() | |
| self.activation = nn.GELU(approx='fast') | |
| self.c_proj = nn.Linear(config.n_embd * config.mlp_ratio, config.n_embd, bias=False) | |
| self.c_fc.set_dtype(config.dtype) | |
| self.activation.set_dtype(config.dtype) | |
| self.c_proj.set_dtype(config.dtype) | |
| def __call__(self, x: Tensor) -> Tensor: | |
| #x1 = self.gate(x) | |
| x = self.c_fc(x) | |
| x = self.activation(x) #* x1 | |
| x = self.c_proj(x) | |
| return x | |
| class Block(nn.Module): | |
| def __init__(self, config, rope): | |
| super().__init__() | |
| self.attn = CausalSelfAttention(config, rope) | |
| self.ln_1 = nn.RMSNorm(config.n_embd) | |
| self.mlp = Mlp(config) | |
| self.ln_2 = nn.RMSNorm(config.n_embd) | |
| def __call__(self, x: Tensor) -> Tensor: | |
| x = x + self.attn(self.ln_1(x)) | |
| x = x + self.mlp(self.ln_2(x)) | |
| return x | |
| class CausalSelfAttention(nn.Module): | |
| def __init__(self, config, rope): | |
| super().__init__() | |
| self.n_embd = config.n_embd | |
| self.n_head = config.n_head | |
| self.n_kv_heads = config.n_kv_heads | |
| self.n_rep = config.n_head // self.n_kv_heads | |
| self.head_dim = config.n_embd // config.n_head | |
| out_dim = self.n_embd + 2 * self.head_dim * self.n_kv_heads | |
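| # Grouped-query attention: n_head query heads share n_kv_heads key/value heads, so the fused qkv projection is n_embd + 2 * n_kv_heads * head_dim wide. | |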
| self.c_attn = nn.Linear(self.n_embd, out_dim, bias=False) | |
| self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) | |
| self.c_proj.set_dtype(config.dtype) | |
| self.c_attn.set_dtype(config.dtype) | |
| self.k_cache = None | |
| self.v_cache = None | |
| self.use_kv_cache = config.use_kv_cache | |
| self.rope = rope | |
| self.scale = 1.0 / math.sqrt(self.head_dim) | |
| def __call__(self, x: Tensor) -> Tensor: | |
| B, T, C = x.shape | |
| qkv = self.c_attn(x) | |
| head_dim = self.n_embd // self.n_head | |
| q_size = self.n_head * head_dim | |
| kv_size = self.n_kv_heads * head_dim | |
| q = qkv[:, :, :q_size] | |
| k = qkv[:, :, q_size:(q_size+kv_size)] | |
| v = qkv[:, :, (q_size+kv_size):] | |
| # got = mlx.core.split(qkv, [q_size, kv_size, kv_size], axis=-1) | |
| # print('got', [(x.shape) for x in got], len(got)) | |
| # import pdb;pdb.set_trace() | |
| # got = mlx.core.split( | |
| # qkv, | |
| # #[self.n_embd, 2 * self.head_dim * self.n_kv_heads], | |
| # self.n_embd, | |
| # axis=2 | |
| # ) | |
| # q, k, v = got | |
| # k, v = mlx.core.split( | |
| # kv, | |
| # 2 * self.head_dim * self.n_kv_heads, | |
| # axis=2 | |
| # ) | |
| q = q.reshape(B, T, self.n_head, C // self.n_head) | |
| q = mlx.core.swapaxes(q, 1, 2) | |
| k = k.reshape(B, T, self.n_kv_heads, self.head_dim) | |
| k = mlx.core.swapaxes(k, 1, 2) | |
| v = v.reshape(B, T, self.n_kv_heads, self.head_dim) | |
| v = mlx.core.swapaxes(v, 1, 2) | |
| q = self.rope(q) | |
| k = self.rope(k) | |
| if self.use_kv_cache: | |
| if self.k_cache is None: | |
| self.k_cache = k | |
| self.v_cache = v | |
| else: | |
| self.k_cache = mlx.core.concatenate([self.k_cache, k], axis=2) | |
| self.v_cache = mlx.core.concatenate([self.v_cache, v], axis=2) | |
| k = self.k_cache | |
| v = self.v_cache | |
| y = mlx.core.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal") | |
| y = y.transpose(0, 2, 1, 3).reshape([B, T, C]) | |
| if self.use_kv_cache: | |
| y = y[:, -1, :].reshape([B, 1, C]) | |
| y = self.c_proj(y) | |
| return y | |
| def reset_key_value_cache(self, enable_key_value_cache: bool): | |
| self.k_cache = None | |
| self.v_cache = None | |
| self.use_kv_cache = enable_key_value_cache | |
| class Transformer(nn.Module): | |
| def __init__(self, config): | |
| super().__init__() | |
| self.wte = nn.Embedding(config.vocab_size, config.n_embd) | |
| rope = nn.RoPE(config.n_embd//config.n_head) | |
| self.h = nn.Sequential(*[Block(config, rope) for _ in range(config.n_layer)]) | |
| self.ln_f = nn.RMSNorm(config.n_embd) | |
| self.wte.set_dtype(config.dtype) | |
| self.ln_f.set_dtype(config.dtype) | |
| self.h.set_dtype(config.dtype) | |
| #for h in self.h: | |
| # h.set_dtype(config.dtype) | |
| class IGptV4(nn.Module): | |
| def __init__(self, config): | |
| super().__init__() | |
| self.config = config | |
| self.transformer = Transformer(config) | |
| self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) | |
| self.lm_head.set_dtype(config.dtype) | |
| # To get the nice loss curve Karpathy gets, we need both of the following tricks: sharing the | |
| # token embedding weights with the output head (weight tying), and initializing the weights the way he does. | |
| # init params | |
| self.apply_to_modules(_init_weights) | |
| self.transformer.wte.weight = self.lm_head.weight | |
| def __call__(self, ids: Tensor, targets: Optional[Tensor]=None) -> Tuple[Tensor, Optional[Tensor]]: | |
| x: Tensor = self.transformer.wte(ids) | |
| x = self.transformer.h(x) | |
| #for (i, layer) in enumerate(self.transformer.h): | |
| # x = layer(x) | |
| x = self.transformer.ln_f(x) | |
| logits = self.lm_head(x) | |
| loss = None | |
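| # Loss is not computed here; the training script's loss_fn applies cross-entropy to the returned logits. | |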
| return logits, loss | |
| def reset_key_value_cache(self, enable_key_value_cache: bool): | |
| for layer in self.transformer.h.layers: | |
| layer.attn.reset_key_value_cache(enable_key_value_cache) | |
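For reference, a minimal forward-pass sketch (hypothetical; the vocab_size mirrors the training script's config):
| if __name__ == '__main__': | |
| config = IGptV4Config(vocab_size=50304) | |
| model = IGptV4(config) | |
| ids = mlx.core.zeros((1, 16), dtype=mlx.core.uint16) | |
| logits, loss = model(ids)  # loss is always None here | |
| print(logits.shape)  # (1, 16, 50304) | |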
q_and_a_reasoning_data_loader.py
| ''' | |
| I wanted to try the Pleias synth dataset. I created two data loaders: one for the synthetic seed text, and one for the reasoning traces. | |
| I filter out non-English columns. I use a custom tokenization for the think phases, which I later learned differs from what Pleias uses, | |
| but I'm focusing on the pre-training phase right now. | |
| ''' | |
| import os | |
| import time | |
| import tiktoken | |
| from tqdm import tqdm | |
| import typing | |
| import mlx.core | |
| from datasets import load_dataset | |
| class QaReasoningDataLoader: | |
| def __init__(self, B: int, T: int, file_path: str, subset='train', start_column=None, tokenizer_name='gpt2_custom'): | |
| self.B = B | |
| self.T = T | |
| if isinstance(file_path, list): | |
| parquets = file_path | |
| else: | |
| parquets = [os.path.join(file_path, fname) for fname in os.listdir(file_path) if fname.endswith('parquet')] | |
| self.dataset: typing.List = load_dataset('parquet', data_files=parquets)[subset] | |
| self.tokens = None | |
| self.tokens_pos = 0 | |
| if tokenizer_name == 'gpt2_custom': | |
| gpt2_encoder = tiktoken.get_encoding('gpt2') | |
| n_vocab = gpt2_encoder.n_vocab | |
| self.extra_tokens = { | |
| "<user>": n_vocab + 1, | |
| "</user>": n_vocab + 2, | |
| "<agent>": n_vocab + 3, | |
| "</agent>": n_vocab + 4, | |
| "<think>": n_vocab + 5, | |
| "</think>": n_vocab + 6, | |
| } | |
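| # These ids are appended after the GPT-2 vocab (n_vocab = 50257); they start at n_vocab + 1, so id 50257 itself stays unused and the model's vocab_size must cover ids up to 50263. | |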
| self.encoder = tiktoken.Encoding( | |
| name="gpt2_custom", | |
| pat_str=gpt2_encoder._pat_str, | |
| mergeable_ranks=gpt2_encoder._mergeable_ranks, | |
| special_tokens={**gpt2_encoder._special_tokens, **self.extra_tokens}, | |
| ) | |
| elif tokenizer_name == 'monad': | |
| from tokenizers import Tokenizer | |
| self.encoder = Tokenizer.from_file("monad_tokenizer.json") | |
| else: | |
| raise NotImplementedError(tokenizer_name) | |
| self.eot = self.encoder._special_tokens['<|endoftext|>'] | |
| self.epoch = 0 | |
| self.current_column = 0 | |
| self.start_column = start_column | |
| if start_column is not None: | |
| self.current_column = start_column | |
| self.current_position = 0 | |
| self.dtype = mlx.core.uint16 | |
| self.batch = mlx.core.zeros((B * T + 1), dtype=self.dtype) | |
| #print('Batch size', B * T) | |
| #total_tokens = 10e9 | |
| #print(f"1 epoch = {self.batches_per_epoch()} batches") | |
| self.progress_bar = None | |
| #self.progress_bar = tqdm(total=len(self.dataset), unit='cols', desc='cols ') | |
| self.encoded = None | |
| self.next_encoded = None | |
| #def __del__(self): | |
| # self.progress_bar.close() | |
| def batches_per_epoch(self): | |
| total_tokens = 10e9 | |
| return total_tokens / (self.B * self.T) | |
| def percent(self): | |
| return self.current_column / len(self.dataset) * 100 | |
| def column_text(self, column): | |
| query = column['query'] | |
| synthetic_reasoning = column['synthetic_reasoning'] | |
| synthetic_answer = column['synthetic_answer'] | |
| return f'''<user>{query}</user><think>{synthetic_reasoning}</think><agent>{synthetic_answer}</agent><|endoftext|>''' | |
| def encode(self): | |
| if self.next_encoded is None: | |
| t0 = time.time() | |
| #tokens_list = [self.eot] + self.encoder.encode(self.dataset[self.current_column], allowed_special={'<|endoftext|>',}) | |
| tokens_list = self.encoder.encode(self.column_text(self.dataset[self.current_column]), allowed_special='all') | |
| self.next_encoded = mlx.core.array(tokens_list, dtype=self.dtype) | |
| t1 = time.time() | |
| #self.durations.append(t1 - t0) | |
| def get_encoded(self): | |
| # Consume the prefetched document. If a batch spans more than one document, the | |
| # prefetched one is already used up, so tokenize the next column on demand. | |
| if self.next_encoded is None: | |
| self.encode() | |
| encoded = self.next_encoded | |
| self.next_encoded = None | |
| return encoded | |
| def reset(self): | |
| self.current_position = 0 | |
| if self.start_column is not None: | |
| self.current_column = self.start_column | |
| else: | |
| self.start_column = 0 | |
| def view(self, tensor, B, T): | |
| return tensor.reshape((B, T)) | |
| def next_column(self): | |
| #print('column', self.current_column) | |
| if self.current_column >= len(self.dataset): | |
| self.current_column = 0 | |
| self.epoch += 1 | |
| if self.progress_bar: | |
| self.progress_bar.reset() | |
| #t0 = time.time() | |
| #if self.backend == SimpleDataLoaderBackend.Torch: | |
| # self.tokens = torch.tensor([self.eot] + self.encoder.encode(self.dataset[self.current_column], allowed_special={'<|endoftext|>',}), dtype=self.dtype) | |
| #else: | |
| # self.tokens = mlx.core.array([self.eot] + self.encoder.encode(self.dataset[self.current_column], allowed_special={'<|endoftext|>',}), dtype=self.dtype) | |
| #t1 = time.time() | |
| #self.durations.append(t1 - t0) | |
| self.tokens = self.get_encoded() | |
| self.tokens_pos = 0 | |
| # Skip non-English columns, wrapping (and counting an epoch) if we run off the end of the dataset. | |
| while True: | |
| self.current_column += 1 | |
| if self.current_column >= len(self.dataset): | |
| self.current_column = 0 | |
| self.epoch += 1 | |
| if self.progress_bar: | |
| self.progress_bar.update(1) | |
| if self.dataset[self.current_column]['language'] == 'en': | |
| break | |
| return self.tokens | |
| def peek_batch(self): | |
| B, T = self.B, self.T | |
| x = self.view(self.batch[:-1], B, T) | |
| y = self.view(self.batch[1:], B, T) | |
| return x, y | |
| def next_batch(self): | |
| batch_pos = 0 | |
| while batch_pos < len(self.batch): | |
| if self.tokens_pos == 0: | |
| self.batch[batch_pos] = self.eot | |
| batch_pos += 1 | |
| if batch_pos == len(self.batch): | |
| break | |
| self.next_column() | |
| batch_remaining = len(self.batch) - batch_pos | |
| assert self.tokens is not None | |
| tokens_remaining = len(self.tokens) - self.tokens_pos | |
| chunk_len = min(batch_remaining, tokens_remaining) | |
| self.batch[batch_pos:batch_pos+chunk_len] = self.tokens[self.tokens_pos:self.tokens_pos+chunk_len] | |
| batch_pos += chunk_len | |
| self.tokens_pos += chunk_len | |
| if self.tokens_pos >= len(self.tokens): | |
| self.tokens_pos = 0 | |
| self.tokens = None | |
| return self.peek_batch() | |
| class QaReasoningSeedTextDataLoader(QaReasoningDataLoader): | |
| def column_text(self, column): | |
| seed_text = column['query_seed_text'] | |
| return f'''{seed_text}<|endoftext|>''' | |
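For reference, a minimal usage sketch of the seed-text loader (hypothetical; the 'synth' directory mirrors the commented-out call in the training script):
| if __name__ == '__main__': | |
| loader = QaReasoningSeedTextDataLoader(B=8, T=1024, file_path='synth', start_column=0) | |
| loader.encode() | |
| x, y = loader.next_batch() | |
| print(x.shape, y.shape)  # (8, 1024) (8, 1024) | |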
Training script
| #!/usr/bin/env python3 | |
| # Training script for IGptV4 | |
| # This training script was mostly generated using the following model since I lost faith in my own | |
| # training script and am flailing a bit right now. However, I'm still having the same issues. | |
| # Model: Claude Sonnet 4.5, Date: 2025-12-12 | |
| import random | |
| import mlx.core as mx | |
| import mlx.nn as nn | |
| import mlx.optimizers as optim | |
| from mlx.utils import tree_map, tree_flatten | |
| from igptv4 import IGptV4, IGptV4Config | |
| from q_and_a_reasoning_data_loader import QaReasoningDataLoader, QaReasoningSeedTextDataLoader | |
| from fineweb_data_loader import FineWebDataLoader | |
| from simple_data_loader import SimpleDataLoaderBackend, SimpleDataLoader | |
| import time | |
| import tiktoken | |
| import math | |
| import wandb | |
| from datetime import datetime | |
| import os | |
| def count_params(model): | |
| """Count model parameters, avoiding double counting tied weights""" | |
| seen_weights = set() | |
| total = 0 | |
| def count_array(arr): | |
| nonlocal total, seen_weights | |
| arr_id = id(arr) | |
| if arr_id not in seen_weights: | |
| seen_weights.add(arr_id) | |
| total += arr.size | |
| def traverse(node): | |
| if isinstance(node, mx.array): | |
| count_array(node) | |
| elif isinstance(node, dict): | |
| for v in node.values(): | |
| traverse(v) | |
| elif isinstance(node, (list, tuple)): | |
| for item in node: | |
| traverse(item) | |
| traverse(model.parameters()) | |
| return total | |
| def format_time(seconds): | |
| """Format seconds into human readable time""" | |
| if seconds < 60: | |
| return f"{seconds:.0f}s" | |
| elif seconds < 3600: | |
| return f"{seconds//60:.0f}m {seconds%60:.0f}s" | |
| else: | |
| return f"{seconds//3600:.0f}h {(seconds%3600)//60:.0f}m" | |
| def get_lr(step, warmup_steps, max_steps, max_lr, min_lr): | |
| """ | |
| Learning rate schedule with warmup and cosine decay. | |
| Must be compilable by MLX (no conditionals based on step). | |
| """ | |
| # Linear warmup | |
| warmup_lr = max_lr * step / warmup_steps | |
| # Cosine decay after warmup | |
| decay_steps = max_steps - warmup_steps | |
| decay_progress = (step - warmup_steps) / decay_steps | |
| cosine_decay = 0.5 * (1.0 + mx.cos(math.pi * decay_progress)) | |
| decay_lr = min_lr + (max_lr - min_lr) * cosine_decay | |
| # Use mx.where for conditional selection (compilable) | |
| lr = mx.where(step < warmup_steps, warmup_lr, decay_lr) | |
| return lr | |
| def loss_fn(model, inputs, targets): | |
| """Compute cross-entropy loss""" | |
| logits, _ = model(inputs) | |
| B, T, V = logits.shape | |
| logits = logits.reshape(B * T, V) | |
| targets = targets.reshape(B * T) | |
| loss = nn.losses.cross_entropy(logits, targets, reduction='mean') | |
| return loss | |
| def train_step(model, optimizer, inputs, targets, accumulated_grads=None): | |
| """Single training step with gradient computation""" | |
| loss_and_grad_fn = nn.value_and_grad(model, loss_fn) | |
| loss, grads = loss_and_grad_fn(model, inputs, targets) | |
| # Force evaluation to free computation graph | |
| mx.eval(loss) | |
| # Accumulate gradients with tree structure handling | |
| if accumulated_grads is None: | |
| accumulated_grads = tree_map(lambda g: g.astype(mx.float32), grads) | |
| else: | |
| accumulated_grads = tree_map(lambda a, g: a + g, accumulated_grads, grads) | |
| # Force evaluation of accumulated gradients to free intermediate graphs | |
| mx.eval(accumulated_grads) | |
| return loss, accumulated_grads | |
| def validate(model, val_loader, num_batches=50): | |
| """Run validation and return average loss and perplexity""" | |
| total_loss = 0.0 | |
| for _ in range(num_batches): | |
| val_loader.encode() | |
| inputs, targets = val_loader.next_batch() | |
| loss = loss_fn(model, inputs, targets) | |
| mx.eval(loss) | |
| total_loss += loss.item() | |
| avg_loss = total_loss / num_batches | |
| perplexity = math.exp(avg_loss) | |
| return avg_loss, perplexity | |
| def compute_gradient_health_metrics(grads): | |
| """ | |
| Compute comprehensive gradient health metrics. | |
| Returns dict with various gradient statistics. | |
| """ | |
| flat_grads = list(tree_flatten(grads, destination={}).values()) | |
| # Global statistics | |
| grad_values = mx.concatenate([g.reshape(-1) for g in flat_grads]) | |
| metrics = { | |
| 'grad_norm': float(mx.sqrt(mx.sum(grad_values * grad_values)).item()), | |
| 'grad_mean': float(mx.mean(grad_values).item()), | |
| 'grad_std': float(mx.std(grad_values).item()), | |
| 'grad_max': float(mx.max(mx.abs(grad_values)).item()), | |
| 'grad_min': float(mx.min(mx.abs(grad_values)).item()), | |
| } | |
| # Per-layer statistics (first, middle, last layers) | |
| layer_grads = {} | |
| for key, grad in tree_flatten(grads, destination={}).items(): | |
| # Extract layer number if present | |
| if 'layers.' in key: | |
| layer_num = key.split('layers.')[1].split('.')[0] | |
| if layer_num not in layer_grads: | |
| layer_grads[layer_num] = [] | |
| layer_grads[layer_num].append(grad) | |
| if layer_grads: | |
| layer_nums = sorted([int(k) for k in layer_grads.keys()]) | |
| if layer_nums: | |
| first_layer = str(layer_nums[0]) | |
| last_layer = str(layer_nums[-1]) | |
| # First layer norm | |
| first_layer_values = mx.concatenate([g.reshape(-1) for g in layer_grads[first_layer]]) | |
| metrics['grad_norm_layer_first'] = float(mx.sqrt(mx.sum(first_layer_values * first_layer_values)).item()) | |
| # Last layer norm | |
| last_layer_values = mx.concatenate([g.reshape(-1) for g in layer_grads[last_layer]]) | |
| metrics['grad_norm_layer_last'] = float(mx.sqrt(mx.sum(last_layer_values * last_layer_values)).item()) | |
| # Ratio (useful for detecting vanishing/exploding) | |
| if metrics['grad_norm_layer_first'] > 0: | |
| metrics['grad_ratio_last_to_first'] = metrics['grad_norm_layer_last'] / metrics['grad_norm_layer_first'] | |
| else: | |
| metrics['grad_ratio_last_to_first'] = 0.0 | |
| return metrics | |
| def chinchilla_optimal_tokens(num_params): | |
| """Calculate Chinchilla-optimal token budget (20 tokens per parameter)""" | |
| return num_params * 20 | |
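| # Example: a ~124M-parameter model gives a budget of roughly 2.5B tokens. | |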
| gpt2_tokenizer = tiktoken.get_encoding('gpt2') | |
| extra_tokens = { | |
| "<user>": gpt2_tokenizer.n_vocab + 1, | |
| "</user>": gpt2_tokenizer.n_vocab + 2, | |
| "<agent>": gpt2_tokenizer.n_vocab + 3, | |
| "</agent>": gpt2_tokenizer.n_vocab + 4, | |
| "<think>": gpt2_tokenizer.n_vocab + 5, | |
| "</think>": gpt2_tokenizer.n_vocab + 6, | |
| } | |
| tokenizer = tiktoken.Encoding( | |
| name="gpt2_custom", | |
| pat_str=gpt2_tokenizer._pat_str, | |
| mergeable_ranks=gpt2_tokenizer._mergeable_ranks, | |
| special_tokens={**gpt2_tokenizer._special_tokens, **extra_tokens}, | |
| ) | |
| def tokenizer_decode(tokens): | |
| decoded = tokenizer.decode(tokens) | |
| return decoded | |
| def tokenizer_encode(text): | |
| decoded = tokenizer.encode(text) | |
| return decoded | |
| @mx.compile | |
| def sample_topk(logits: mx.array, temperature: float = 0.5, k: int = 5) -> mx.array: | |
| # Sample uniformly from the k highest-probability tokens. | |
| logits = logits / temperature | |
| probs = mx.softmax(logits, axis=-1) | |
| # argpartition pivots on position (n - k) so that the last k indices are the top-k tokens. | |
| args = mx.argpartition(probs, probs.shape[-1] - k, axis=-1)[..., -k:] | |
| idx = mx.random.randint(0, k) | |
| return args[idx] | |
| def infer(model, max_tokens=50, prompt='Once upon a time,'): | |
| tokens = tokenizer_encode(prompt) | |
| model.train(False) | |
| model.reset_key_value_cache(True) | |
| for _ in (range(max_tokens)): | |
| token_id_array = mx.array(tokens, dtype=mx.uint16) | |
| token_id_array = token_id_array.reshape((1, *token_id_array.shape)) | |
| result, loss = model(token_id_array, None) | |
| result = result[-1,-1] | |
| chosen = sample_topk(result) | |
| del token_id_array | |
| tokens.append(chosen.item()) | |
| try: | |
| decoded = tokenizer_decode(tokens) | |
| except KeyError as e: | |
| print(e) | |
| decoded = '' | |
| model.reset_key_value_cache(False) | |
| model.train(True) | |
| return decoded | |
| def main(): | |
| # Base hyperparameters (adjustable) | |
| micro_batch_size = 8 # Batch size per gradient step | |
| sequence_length = 1024 # sequence length | |
| # We want a random but reproducible (identical between runs) shuffle of the dataset. | |
| rng_seed = 42 | |
| random.seed(rng_seed) | |
| # Training configuration | |
| max_learning_rate = 1e-4 # Lowered from 6e-4 while debugging vanishing gradients | |
| min_learning_rate = 0.1 * max_learning_rate | |
| weight_decay = 0.1 | |
| grad_clip_norm = 1.0 # Gradient clipping threshold (1.0 is standard for transformers) | |
| log_interval = 100 | |
| val_interval = 500 | |
| save_interval = 1000 | |
| infer_interval = 500 | |
| grad_health_interval = 200 # Check gradient health metrics | |
| overwrite_checkpoints = True # If True, only keep the last checkpoint to save disk space | |
| memory_limit = 20 * 1024 ** 3 # 20GB memory limit | |
| should_save = True # Set to False to disable all saving (checkpoints, wandb) for testing | |
| use_wandb = True # Set to False to disable wandb logging (overridden by should_save) | |
| # Warmup will be calculated as a fraction of max_steps later | |
| warmup_fraction = 0.05 # 5% of training for warmup | |
| # Generate timestamp for this training run | |
| run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | |
| # Create output directory for this run (only if saving) | |
| if should_save: | |
| output_dir = f"runs/run_{run_timestamp}" | |
| os.makedirs(output_dir, exist_ok=True) | |
| print(f"Output directory: {output_dir}") | |
| else: | |
| output_dir = None | |
| print("Saving disabled - running in test mode") | |
| # Model configuration | |
| config = IGptV4Config( | |
| vocab_size=50304, | |
| n_layer=12, | |
| n_head=12, | |
| n_embd=768, | |
| n_kv_heads=4, | |
| dtype=mx.bfloat16, | |
| mlp_ratio=4, | |
| use_kv_cache=False | |
| ) | |
| print("Initializing model...") | |
| model = IGptV4(config) | |
| mx.eval(model.parameters()) | |
| # Count parameters (accounting for tied weights) | |
| num_params = count_params(model) | |
| # Calculate Chinchilla-optimal token budget | |
| chinchilla_tokens = chinchilla_optimal_tokens(num_params) | |
| # Decide on a reasonable number of tokens per optimizer update (64K tokens per batch here) | |
| target_tokens_per_batch = 64 * 1024 | |
| # Calculate gradient accumulation steps | |
| tokens_per_micro_batch = micro_batch_size * sequence_length | |
| grad_accum_steps = target_tokens_per_batch // tokens_per_micro_batch | |
| # Calculate effective batch size | |
| effective_batch_size = micro_batch_size * grad_accum_steps | |
| effective_tokens_per_batch = effective_batch_size * sequence_length | |
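| # With the defaults above: tokens_per_micro_batch = 8 * 1024 = 8192, | |
| # grad_accum_steps = 65536 // 8192 = 8, effective_batch_size = 64 (65,536 tokens per optimizer update). | |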
| # Calculate number of steps needed to reach Chinchilla-optimal tokens | |
| # Each step processes one micro batch | |
| max_steps = chinchilla_tokens // tokens_per_micro_batch | |
| # Calculate warmup steps as a fraction of total steps | |
| warmup_steps = int(warmup_fraction * max_steps) | |
| # Calculate number of optimizer updates | |
| num_optimizer_updates = max_steps // grad_accum_steps | |
| # Total tokens that will be seen | |
| total_tokens_seen = max_steps * tokens_per_micro_batch | |
| shuffle_dataset = True | |
| print(f"Model parameters: {num_params:,}") | |
| print(f"Chinchilla optimal tokens: {chinchilla_tokens:,} ({chinchilla_tokens/1e9:.2f}B)") | |
| print(f"Target tokens per batch: {target_tokens_per_batch:,}") | |
| print(f"Tokens per micro batch: {tokens_per_micro_batch:,}") | |
| print(f"Gradient accumulation steps: {grad_accum_steps}") | |
| print(f"Effective batch size: {effective_batch_size} ({effective_tokens_per_batch:,} tokens)") | |
| print(f"Calculated loop steps: {max_steps:,}") | |
| print(f"Optimizer updates: {num_optimizer_updates:,}") | |
| print(f"Total tokens in training: {total_tokens_seen:,} ({total_tokens_seen/1e9:.2f}B)") | |
| print(f"Chinchilla ratio: {total_tokens_seen/chinchilla_tokens:.2f}x") | |
| # Initialize wandb (only if saving is enabled) | |
| if should_save and use_wandb: | |
| wandb.init( | |
| project="igptv4-training", | |
| save_code=True, | |
| config={ | |
| "vocab_size": config.vocab_size, | |
| "n_layer": config.n_layer, | |
| "n_head": config.n_head, | |
| "n_embd": config.n_embd, | |
| "n_kv_heads": config.n_kv_heads, | |
| "mlp_ratio": config.mlp_ratio, | |
| "micro_batch_size": micro_batch_size, | |
| "sequence_length": sequence_length, | |
| "grad_accum_steps": grad_accum_steps, | |
| "effective_batch_size": effective_batch_size, | |
| "effective_tokens_per_batch": effective_tokens_per_batch, | |
| "target_tokens_per_batch": target_tokens_per_batch, | |
| "max_steps": max_steps, | |
| "max_learning_rate": max_learning_rate, | |
| "min_learning_rate": min_learning_rate, | |
| "weight_decay": weight_decay, | |
| "warmup_steps": warmup_steps, | |
| "num_params": num_params, | |
| "chinchilla_optimal_tokens": chinchilla_tokens, | |
| "total_tokens": total_tokens_seen, | |
| "shuffle_dataset": shuffle_dataset, | |
| } | |
| ) | |
| # Initialize optimizer | |
| optimizer = optim.AdamW(learning_rate=max_learning_rate, weight_decay=weight_decay) | |
| # Initialize data loaders | |
| print("Loading dataset...") | |
| # train_loader = QaReasoningSeedTextDataLoader(micro_batch_size, sequence_length, 'synth', start_column=0, shuffle=shuffle_dataset) | |
| # validation_loader = QaReasoningSeedTextDataLoader(micro_batch_size, sequence_length, 'synth/validate', start_column=0) | |
| train_loader = FineWebDataLoader(micro_batch_size, sequence_length, 'edu_fineweb10B', SimpleDataLoaderBackend.Mlx, start_column=0, shuffle=shuffle_dataset) | |
| validation_loader = FineWebDataLoader(micro_batch_size, sequence_length, 'edu_fineweb10B/validate', SimpleDataLoaderBackend.Mlx, start_column=0) | |
| # Pre-encode first batch | |
| train_loader.encode() | |
| print(f"\nStarting training for {max_steps} steps...") | |
| print(f"Micro batch size: {micro_batch_size}, Sequence length: {sequence_length}") | |
| print(f"Gradient accumulation steps: {grad_accum_steps}") | |
| print(f"Effective batch size: {effective_batch_size} ({effective_tokens_per_batch:,} tokens)") | |
| print(f"Learning rate: {max_learning_rate} -> {min_learning_rate} (warmup: {warmup_steps} steps)") | |
| print(f"Weight decay: {weight_decay}") | |
| print(f"Log interval: {log_interval}, Val interval: {val_interval}, Save interval: {save_interval}, Inference interval {infer_interval}\n") | |
| print("Step, Loss, Perplexity, LR, Tokens/sec, Time, ETA, GradNorm, ValLoss, ValPpl") | |
| start_time = time.time() | |
| global_start_time = start_time | |
| running_loss = 0.0 | |
| accumulated_grads = None | |
| steps_since_log = 0 | |
| total_tokens_processed = 0 | |
| last_checkpoint_path = None # Track last checkpoint for deletion | |
| running_grad_norm = 0.0 # Track gradient norms for logging | |
| last_grad_health = None # Store last gradient health metrics | |
| for step in range(max_steps): | |
| step_start = time.time() | |
| # Memory limit check | |
| if mx.get_active_memory() > memory_limit: | |
| print(f"\nMemory limit exceeded at step {step}!") | |
| print(f"Active memory: {mx.get_active_memory() / 1024**3:.2f}GB") | |
| print(f"Memory limit: {memory_limit / 1024**3:.2f}GB") | |
| break | |
| # Update learning rate with warmup and cosine decay | |
| current_lr = get_lr(mx.array(step), warmup_steps, max_steps, max_learning_rate, min_learning_rate) | |
| optimizer.learning_rate = current_lr.item() | |
| # Get batch | |
| inputs, targets = train_loader.next_batch() | |
| # Pre-encode next batch while training on current | |
| train_loader.encode() | |
| # Training step with gradient accumulation | |
| loss, accumulated_grads = train_step(model, optimizer, inputs, targets, accumulated_grads) | |
| # Force evaluation of accumulated grads every step to prevent memory buildup | |
| mx.eval(accumulated_grads) | |
| running_loss += loss.item() | |
| steps_since_log += 1 | |
| total_tokens_processed += tokens_per_micro_batch | |
| # Update weights after accumulating gradients | |
| if (step + 1) % grad_accum_steps == 0: | |
| # Average the accumulated gradients | |
| scale = 1.0 / grad_accum_steps | |
| averaged_grads = tree_map(lambda g: g * scale, accumulated_grads) | |
| mx.eval(averaged_grads) | |
| # Compute gradient health metrics periodically | |
| # if (((step + 1) / grad_accum_steps)+1) % grad_health_interval == 0: | |
| # last_grad_health = compute_gradient_health_metrics(averaged_grads) | |
| # | |
| # print(f"\n=== Gradient Health Check at Step {step + 1} ===") | |
| # print(f" Grad Norm: {last_grad_health['grad_norm']:.6f}") | |
| # print(f" Grad Mean: {last_grad_health['grad_mean']:.6e}") | |
| # print(f" Grad Std: {last_grad_health['grad_std']:.6e}") | |
| # print(f" Grad Max: {last_grad_health['grad_max']:.6e}") | |
| # print(f" Grad Min: {last_grad_health['grad_min']:.6e}") | |
| # | |
| # if 'grad_norm_layer_first' in last_grad_health: | |
| # print(f" First Layer Norm: {last_grad_health['grad_norm_layer_first']:.6f}") | |
| # print(f" Last Layer Norm: {last_grad_health['grad_norm_layer_last']:.6f}") | |
| # print(f" Ratio (Last/First): {last_grad_health['grad_ratio_last_to_first']:.4f}") | |
| # | |
| # # Diagnostic messages | |
| # if last_grad_health['grad_ratio_last_to_first'] < 0.01: | |
| # print(" ⚠️ WARNING: Severe vanishing gradients detected!") | |
| # elif last_grad_health['grad_ratio_last_to_first'] > 100: | |
| # print(" ⚠️ WARNING: Exploding gradients detected!") | |
| # else: | |
| # print(" ✓ Gradient flow looks healthy") | |
| # | |
| # print("=" * 50 + "\n") | |
| # | |
| # # Log to wandb | |
| # if should_save and use_wandb: | |
| # wandb.log({ | |
| # "grad_health/norm": last_grad_health['grad_norm'], | |
| # "grad_health/mean": last_grad_health['grad_mean'], | |
| # "grad_health/std": last_grad_health['grad_std'], | |
| # "grad_health/max": last_grad_health['grad_max'], | |
| # "grad_health/min": last_grad_health['grad_min'], | |
| # "grad_health/norm_layer_first": last_grad_health.get('grad_norm_layer_first', 0), | |
| # "grad_health/norm_layer_last": last_grad_health.get('grad_norm_layer_last', 0), | |
| # "grad_health/ratio_last_to_first": last_grad_health.get('grad_ratio_last_to_first', 0), | |
| # "step": step + 1, | |
| # "tokens": total_tokens_processed, | |
| # }) | |
| # Apply gradient clipping by global norm | |
| flat_grads = tree_flatten(averaged_grads, destination={}) | |
| grad_norm = mx.sqrt(sum(mx.sum(g * g) for g in flat_grads.values())) | |
| mx.eval(grad_norm) | |
| grad_norm_value = float(grad_norm.item()) | |
| running_grad_norm += grad_norm_value | |
| if grad_norm_value > grad_clip_norm: | |
| scale = grad_clip_norm / grad_norm_value | |
| averaged_grads = tree_map(lambda g: g * scale, averaged_grads) | |
| mx.eval(averaged_grads) | |
| optimizer.update(model, averaged_grads) | |
| mx.eval(model.parameters()) | |
| # Clear accumulated gradients | |
| accumulated_grads = None | |
| # Validation | |
| val_loss_str = "" | |
| val_ppl_str = "" | |
| if (step + 1) % val_interval == 0: | |
| val_loss, val_ppl = validate(model, validation_loader) | |
| val_loss_str = f"{val_loss:.4f}" | |
| val_ppl_str = f"{val_ppl:.2f}" | |
| # Convert to Python floats before logging to prevent wandb from holding references | |
| if should_save and use_wandb: | |
| wandb.log({ | |
| "val/loss": float(val_loss), | |
| "val/perplexity": float(val_ppl), | |
| "step": step + 1, | |
| "tokens": total_tokens_processed, | |
| }) | |
| if (step + 1) % infer_interval == 0: | |
| text = infer(model) | |
| print('Inference sample: ', text) | |
| # Logging | |
| if (step + 1) % log_interval == 0: | |
| avg_loss = running_loss / steps_since_log | |
| elapsed = time.time() - start_time | |
| tokens_per_sec = (tokens_per_micro_batch * steps_since_log) / elapsed | |
| # Calculate ETA based on actual time per step (including all overheads) | |
| total_elapsed = time.time() - global_start_time | |
| avg_time_per_step = total_elapsed / (step + 1) | |
| remaining_steps = max_steps - (step + 1) | |
| eta_seconds = avg_time_per_step * remaining_steps | |
| # Compute perplexity | |
| ppl = math.exp(avg_loss) | |
| current_lr = float(optimizer.learning_rate) | |
| # Average gradient norm over optimizer updates since last log | |
| num_updates_since_log = steps_since_log // grad_accum_steps | |
| avg_grad_norm = running_grad_norm / max(num_updates_since_log, 1) | |
| print(f"{step + 1}/{max_steps}, " | |
| f"{avg_loss:.4f}, " | |
| f"{ppl:.2f}, " | |
| f"{current_lr:.2e}, " | |
| f"{tokens_per_sec:.0f}, " | |
| f"{elapsed:.1f}s, " | |
| f"{format_time(eta_seconds)}, " | |
| f"{avg_grad_norm:.3f}, " | |
| f"{val_loss_str}, " | |
| f"{val_ppl_str}") | |
| # Convert all values to Python types before logging to wandb | |
| if should_save and use_wandb: | |
| wandb.log({ | |
| "train/loss": float(avg_loss), | |
| "train/perplexity": float(ppl), | |
| "train/learning_rate": float(current_lr), | |
| "train/tokens_per_sec": float(tokens_per_sec), | |
| "train/grad_norm": float(avg_grad_norm), | |
| "step": int(step + 1), | |
| "tokens": int(total_tokens_processed), | |
| }) | |
| running_loss = 0.0 | |
| running_grad_norm = 0.0 | |
| steps_since_log = 0 | |
| start_time = time.time() | |
| # Save checkpoint | |
| if should_save and (step + 1) % save_interval == 0: | |
| # Delete previous checkpoint if overwrite mode is enabled | |
| if overwrite_checkpoints and last_checkpoint_path is not None: | |
| if os.path.exists(last_checkpoint_path): | |
| os.remove(last_checkpoint_path) | |
| print(f"Deleted previous checkpoint: {last_checkpoint_path}") | |
| opt_path = last_checkpoint_path.replace('.safetensors', '_optimizer.safetensors') | |
| if os.path.exists(opt_path): | |
| os.remove(opt_path) | |
| # Save new checkpoint | |
| checkpoint_path = os.path.join(output_dir, f"checkpoint_step_{step + 1}.safetensors") | |
| optimizer_path = os.path.join(output_dir, f"checkpoint_step_{step + 1}_optimizer.safetensors") | |
| print(f"Saving checkpoint to {checkpoint_path}...") | |
| model.save_weights(checkpoint_path) | |
| # Save optimizer state using tree_flatten | |
| print(f"Saving optimizer state to {optimizer_path}...") | |
| optimizer_state = tree_flatten(optimizer.state, destination={}) | |
| mx.save_safetensors(optimizer_path, optimizer_state) | |
| last_checkpoint_path = checkpoint_path | |
| print("\nTraining complete!") | |
| # Save final model (only if saving is enabled) | |
| if should_save: | |
| final_model_path = os.path.join(output_dir, "model_final.safetensors") | |
| final_optimizer_path = os.path.join(output_dir, "model_final_optimizer.safetensors") | |
| print(f"Saving final model to {final_model_path}...") | |
| model.save_weights(final_model_path) | |
| print(f"Saving final optimizer state to {final_optimizer_path}...") | |
| optimizer_state = tree_flatten(optimizer.state, destination={}) | |
| mx.save_safetensors(final_optimizer_path, optimizer_state) | |
| print("Model and optimizer saved!") | |
| else: | |
| print("Saving disabled - no final model saved") | |
| if should_save and use_wandb: | |
| wandb.finish() | |
| if __name__ == "__main__": | |
| main() |