Hello there! I’m trying to train a custom LLM along the lines of Andrej Karpathy’s nanoGPT and nanochat tutorials. My issue is that the training loss and gradient norms go to nearly zero after around a hundred steps. I’m using the MLX framework on an M1 Max.
Code, raw logs, graphs of the training loss, validation loss, and gradient norms, and the raw CSV data are all available in this GitHub gist: https://gist.github.com/iankronquist/68bc7e51178aef47dd225074e5310814#file-trainingruninfo-md
I have a rather Llama-like architecture with RoPE. Unlike Llama, I use GELU (like GPT-2) instead of SwiGLU in the MLP to save a few parameters on the gate matrices. I’m using an embedding dimension of 768, 12 layers, an MLP up-projection ratio of 4, and grouped-query attention with 4 key/value heads (dimensions like GPT-2 small and Llama). I’m using the GPT-2 tokenizer with a vocab size of 50304. This comes out to around 114M parameters, so I seem to be on the beaten path for small LLMs.
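For reference, here is a back-of-the-envelope parameter count for that shape. It is a sketch under my assumptions (tied input/output embeddings, no bias terms, weight-only norms), but it lands on the 114,150,144 parameters that wandb reports:

# Rough parameter count for the config above.
# Assumes tied embeddings, no bias terms, and weight-only norms (RMSNorm).
n_embd, n_layer, n_head, n_kv_heads, mlp_ratio, vocab = 768, 12, 12, 4, 4, 50304
head_dim = n_embd // n_head                       # 64

embed = vocab * n_embd                            # token embedding, tied with the LM head
attn = (n_embd * n_embd                           # Q projection
        + 2 * n_embd * n_kv_heads * head_dim      # K and V projections (GQA: 4 KV heads)
        + n_embd * n_embd)                        # output projection
mlp = 2 * n_embd * (mlp_ratio * n_embd)           # up + down projections (GELU, no gate)
norms = n_layer * 2 * n_embd + n_embd             # two norms per block + final norm

total = embed + n_layer * (attn + mlp) + norms
print(f"{total:,}")                               # 114,150,144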
Chinchilla scaling leads me to believe I should train on roughly 2.28B tokens (about 20 tokens per parameter), and at 7.7k tokens per second I think that should take a bit over 80 hours on my laptop. However, due to the collapse in gradient norms, the total collapse of training loss, and a rise in validation loss, I keep killing the training runs after 2-8 hours to tweak things like the learning rate and batch size.
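The budget and wall-clock estimate are just this arithmetic (a sketch; 20 tokens per parameter is the usual Chinchilla rule of thumb, and 7.7k tokens/sec is my measured throughput):

# Chinchilla-style token budget and wall-clock estimate.
num_params = 114_150_144
token_budget = 20 * num_params              # ~20 tokens per parameter
tokens_per_sec = 7_700                      # measured throughput on the M1 Max

hours = token_budget / tokens_per_sec / 3600
print(f"{token_budget / 1e9:.2f}B tokens, ~{hours:.0f} hours")  # 2.28B tokens, ~82 hours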
I have had success training on the TinyStories dataset, but I’ve had more trouble with more diverse datasets like subsets of FineWeb-Edu 10B, Pleias synth, and a subset of Wikipedia.
I’ve done several runs with different datasets, learning rates, batch sizes, and other tricks. I’ve asked Claude Sonnet, GPT-5.2, and Gemini for help, but they deliver inconsistent and occasionally contradictory advice. At this point I’m looking for suggestions from an actual expert, as I’m out of ideas.
I’d appreciate any thoughts, feedback, or suggestions. I can expand on anything, provide more logs or data, and run limited sweeps.
Files in the gist:
- The model is in igptv4.py
- The training script is in claude_train3.py
- The data loader is in fineweb_data_loader.py
We can see that the norm of the gradients starts low, peaks high and noisy, and then falls to nearly zero (0.08 and falling) before 100 steps. I could be wrong, but I think this means the model learns very little from each additional sample, i.e. my gradients have well and truly vanished.
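By “norm of gradients” I mean the global L2 norm over all parameter gradients, which in MLX can be computed along these lines (a minimal sketch, not the exact code from claude_train3.py):

import mlx.core as mx
from mlx.utils import tree_flatten

def global_grad_norm(grads):
    # Global L2 norm over the whole gradient tree: sqrt(sum of squared entries).
    flat = tree_flatten(grads)  # list of (name, array) pairs
    return mx.sqrt(sum(mx.sum(g.astype(mx.float32) ** 2) for _, g in flat))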
Training loss starts near log_e(50304) ≈ 10.8, as expected for an untrained model, and falls below 0.01 in under 100 steps.
Validation loss starts under 10 and climbs to just under 10.8. I think this means that early on the model learns to repeat the most frequent tokens in the dataset (e.g. "of of of") and then later starts memorizing the training data in fragments. I tried reducing the learning rate from 1e-3 to 1e-4 and reducing the total batch size from 512*1024 to 64*1024 tokens, but that didn't seem to help.
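To probe the memorization hypothesis I can fingerprint the micro-batches coming out of the loader and check whether the same sequences recur more often than they should. A minimal sketch (the loader interface here is a stand-in, not the actual API of fineweb_data_loader.py):

import hashlib
from collections import Counter
import numpy as np

def batch_fingerprints(batches, limit=2000):
    # Hash the token ids of each (inputs, targets) micro-batch. Heavy repeats
    # would point at the loader recycling the same shard, which could explain a
    # near-zero training loss alongside a rising validation loss.
    counts = Counter()
    for i, (inputs, _targets) in enumerate(batches):
        if i >= limit:
            break
        counts[hashlib.sha1(np.array(inputs).tobytes()).hexdigest()] += 1
    return counts.most_common(10)   # ideally every count is 1

Running this over a few thousand micro-batches from fineweb_data_loader.py would show whether the data stream is as varied as I assume.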
Here we see our cosine-decay learning rate schedule. We will train over ~2.28B tokens and warm up for 5% of that. Our gradient norms have collapsed well before we finish the warmup.
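Concretely, the schedule is a linear warmup to the max learning rate followed by cosine decay down to the minimum, roughly like this (a sketch that reproduces the logged values, e.g. 9.84e-5 at step 36500, but not the exact code in claude_train3.py):

import math

def learning_rate(step, max_lr=1e-4, min_lr=1e-5, warmup_steps=13934, max_steps=278686):
    # Linear warmup to max_lr, then cosine decay from max_lr down to min_lr.
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))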
# Imports needed by this excerpt (elided in the original)
import random
import mlx.core as mx
import mlx.optimizers as optim
from igptv4 import IGptV4Config  # assuming the config class lives in igptv4.py from the gist

# Base hyperparameters (adjustable)
micro_batch_size = 8   # Batch size per gradient step
sequence_length = 1024 # Sequence length in tokens

# We want a random but predictable between-run shuffle of the dataset.
rng_seed = 42
random.seed(rng_seed)

# Training configuration
max_learning_rate = 1e-4 # Reduced from 6e-4 due to vanishing gradients
min_learning_rate = 0.1 * max_learning_rate
weight_decay = 0.1
grad_clip_norm = 1.0 # Gradient clipping threshold (1.0 is standard for transformers)

# Token budget and warmup fraction (values from the logged wandb config;
# defined elsewhere in the full script)
chinchilla_tokens = 2_283_002_880  # ~20 tokens per parameter
warmup_fraction = 0.05
# ...
# Model configuration
config = IGptV4Config(
    vocab_size=50304,
    n_layer=12,
    n_head=12,
    n_embd=768,
    n_kv_heads=4,
    dtype=mx.bfloat16,
    mlp_ratio=4,
    use_kv_cache=False,
)
# Target tokens per optimizer update (reduced from 512*1024 = 0.5M to 64*1024 tokens)
target_tokens_per_batch = 64 * 1024
# Calculate gradient accumulation steps
tokens_per_micro_batch = micro_batch_size * sequence_length
grad_accum_steps = target_tokens_per_batch // tokens_per_micro_batch
# Calculate effective batch size
effective_batch_size = micro_batch_size * grad_accum_steps
effective_tokens_per_batch = effective_batch_size * sequence_length
# Calculate number of steps needed to reach Chinchilla-optimal tokens
# Each step processes one micro batch
max_steps = chinchilla_tokens // tokens_per_micro_batch
# Calculate warmup steps as a fraction of total steps
warmup_steps = int(warmup_fraction * max_steps)
# Calculate number of optimizer updates
num_optimizer_updates = max_steps // grad_accum_steps
# Total tokens that will be seen
total_tokens_seen = max_steps * tokens_per_micro_batch
# Initialize optimizer
optimizer = optim.AdamW(learning_rate=max_learning_rate, weight_decay=weight_decay)

Starting training for 278686 steps...
2025-12-14 01:07:47  Micro batch size: 8, Sequence length: 1024
2025-12-14 01:07:47  Gradient accumulation steps: 8
2025-12-14 01:07:47  Effective batch size: 64 (65,536 tokens)
2025-12-14 01:07:47  Learning rate: 0.0001 -> 1e-05 (warmup: 13934 steps)
2025-12-14 01:07:47  Weight decay: 0.1
2025-12-14 01:07:47  Log interval: 100, Val interval: 500, Save interval: 1000, Inference interval 500
2025-12-14 01:07:47  Step, Loss, Perplexity, LR, Tokens/sec, Time, ETA, GradNorm, ValLoss, ValPpl
2025-12-14 01:10:10  100/278686, 10.7878, 48424.54, 7.10e-07, 5712, 143.4s, 110h 59m, 6.271, ,
2025-12-14 01:12:36  200/278686, 9.9379, 20699.86, 1.43e-06, 5636, 145.4s, 111h 41m, 3.950, ,
2025-12-14 01:15:01  300/278686, 9.2725, 10641.04, 2.15e-06, 5651, 145.0s, 111h 48m, 2.812, ,
2025-12-14 01:17:26  400/278686, 8.5748, 5296.58, 2.86e-06, 5647, 145.1s, 111h 51m, 3.391, ,
2025-12-14 01:20:05  Inference sample: Once upon a time, ofsssssss- to to to to to to to to to to to to to to to to to to,, the, and and and and and and and and and and and and and and the, the, the
2025-12-14 01:20:05  500/278686, 7.4223, 1672.84, 3.58e-06, 5136, 159.5s, 114h 6m, 3.849, 9.9402, 20747.81
2025-12-14 01:22:31  600/278686, 6.1542, 470.71, 4.30e-06, 5636, 145.4s, 113h 45m, 5.542, ,
2025-12-14 01:24:55  700/278686, 5.1426, 171.15, 5.02e-06, 5660, 144.7s, 113h 26m, 7.100, ,
2025-12-14 01:27:20  800/278686, 4.4050, 81.86, 5.73e-06, 5654, 144.9s, 113h 12m, 8.522, ,
2025-12-14 01:29:45  900/278686, 3.8433, 46.68, 6.45e-06, 5667, 144.6s, 112h 59m, 9.376, ,
2025-12-14 01:32:26  Inference sample: Once upon a time,. and and and to to. of for Sudan for Sudan for it of Sudan. It of’ for. It for it of�����������������������
2025-12-14 01:32:26  1000/278686, 3.3757, 29.25, 7.17e-06, 5079, 161.3s, 114h 5m, 10.885, 10.0700, 23624.26
And later:
2025-12-14 16:08:00  Inference sample: Once upon a time, which is as well similarly because safe as well in despite as all a economic equally poses equally true as all are� a prime transition� growingterrorism a constructiveise of economic hardships hit which Sudan for which Sudanese population living difficult shocks months by for which
2025-12-14 16:08:00  36500/278686, 0.0017, 1.00, 9.84e-05, 5085, 161.1s, 99h 33m, 0.006, 10.7390, 46118.28
2025-12-14 16:10:28  36600/278686, 0.0017, 1.00, 9.84e-05, 5565, 147.2s, 99h 30m, 0.007, ,
2025-12-14 16:12:53  36700/278686, 0.0018, 1.00, 9.84e-05, 5622, 145.7s, 99h 27m, 0.006, ,
2025-12-14 16:15:20  36800/278686, 0.0017, 1.00, 9.84e-05, 5598, 146.3s, 99h 25m, 0.007, ,
2025-12-14 16:17:46  36900/278686, 0.0017, 1.00, 9.83e-05, 5606, 146.1s, 99h 22m, 0.006, ,
2025-12-14 16:20:28  Inference sample: Once upon a time, and as as president as a a more people for the international a constructive state it well well billion be a much makes acts makes acts as all easy steps removal made worse taken as president months diminish facto andils similarly Ham threat resilient resilient resilient resilient an unsustainable
2025-12-14 16:20:28  37000/278686, 0.0017, 1.00, 9.83e-05, 5050, 162.2s, 99h 21m, 0.007, 10.7856, 48319.02
2025-12-14 16:20:28  Deleted previous checkpoint: runs/run_20251214_010741/checkpoint_step_36000.safetensors
2025-12-14 16:20:28  Saving checkpoint to runs/run_20251214_010741/checkpoint_step_37000.safetensors...
2025-12-14 16:20:29  Saving optimizer state to runs/run_20251214_010741/checkpoint_step_37000_optimizer.safetensors...
2025-12-14 16:22:56  37100/278686, 0.0017, 1.00, 9.83e-05, 5526, 148.3s, 99h 19m, 0.006, ,
2025-12-14 16:25:22  37200/278686, 0.0017, 1.00, 9.83e-05, 5615, 145.9s, 99h 16m, 0.007, ,
Some info reported via wandb:
{
"_wandb": {
"value": {
"e": {
"gpd7nsdetguq72lz5u5xxwydmwlp0xjv": {
"os": "macOS-15.3-arm64-arm-64bit-Mach-O",
"git": {
"commit": "59239a83cb20904c7ce8633d63fa2462abf57523",
"remote": "<censored>"
},
"disk": {
"/": {
"used": "964334108672",
"total": "994662584320"
}
},
"host": "ianksj316.local",
"root": "/Users/ian/gg/learnllm/gpt2-repro",
"apple": {
"name": "Apple M1 Max",
"gpuCores": 32,
"memoryGb": 32,
"ecpuCores": 2,
"pcpuCores": 8,
"ramTotalBytes": "34359738368",
"swapTotalBytes": "1073741824"
},
"email": "<censored>",
"memory": {
"total": "34359738368"
},
"python": "CPython 3.13.7",
"program": "<censored>claude_train3.py",
"codePath": "gpt2-repro/claude_train3.py",
"writerId": "<censored>",
"cpu_count": 10,
"startedAt": "2025-12-14T09:07:41.921373Z",
"executable": <censored>",
"codePathLocal": "claude_train3.py",
"cpu_count_logical": 10
}
},
"m": [],
"t": {
"1": [
1,
49,
51
],
"2": [
1,
49,
51
],
"3": [
16
],
"4": "3.13.7",
"5": "0.22.2",
"12": "0.22.2",
"13": "darwin-arm64"
},
"code_path": "code/gpt2-repro/claude_train3.py",
"cli_version": "0.22.2",
"python_version": "3.13.7"
}
},
"n_embd": {
"value": 768
},
"n_head": {
"value": 12
},
"n_layer": {
"value": 12
},
"max_steps": {
"value": 278686
},
"mlp_ratio": {
"value": 4
},
"n_kv_heads": {
"value": 4
},
"num_params": {
"value": 114150144
},
"vocab_size": {
"value": 50304
},
"total_tokens": {
"value":G 2282995712
},
"warmup_steps": {
"value": 13934
},
"weight_decay": {
"value": 0.1
},
"sequence_length": {
"value": 1024
},
"shuffle_dataset": {
"value": true
},
"grad_accum_steps": {
"value": 8
},
"micro_batch_size": {
"value": 8
},
"max_learning_rate": {
"value": 0.0001
},
"min_learning_rate": {
"value": 0.00001
},
"effective_batch_size": {
"value": 64
},
"target_tokens_per_batch": {
"value": 65536
},
"chinchilla_optimal_tokens": {
"value": 2283002880
},
"effective_tokens_per_batch": {
"value": 65536
}
}

CSV files of the gradient norm over time and the training loss over time are also attached to the gist.

