@vgoklani · Last active December 25, 2025
Very specific case that leads to NaNs on newer versions of PyTorch. The repro script is below; a small diagnostic sketch follows it.
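Because the failure is tied to "newer versions" of PyTorch, it helps to record the exact toolchain when reporting or comparing runs. A minimal sketch (not part of the original gist, assuming a CUDA machine) that prints the relevant version information:

import torch

# Record the PyTorch build, CUDA/cuDNN versions, and GPU the repro ran on.
print("torch:", torch.__version__)
print("cuda:", torch.version.cuda)
print("cudnn:", torch.backends.cudnn.version())
print("gpu:", torch.cuda.get_device_name(0))

The repro script: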
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

VOCAB = 65536
SEQ = 2048
BATCH = 32  # fails with 32, works with 16
N_LAYER = 20
N_EMBD = 1280
N_HEAD = 10
GRAD_ACCUM = 8


def norm(x):
    return F.rms_norm(x, (x.size(-1),))


class Attn(nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv = nn.Linear(N_EMBD, 3 * N_EMBD, bias=False)
        self.proj = nn.Linear(N_EMBD, N_EMBD, bias=False)

    def forward(self, x):
        B, T, _ = x.shape
        qkv = self.qkv(x).view(B, T, 3, N_HEAD, N_EMBD // N_HEAD)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        q, k = norm(q), norm(k)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.proj(y.transpose(1, 2).contiguous().view(B, T, -1))


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(VOCAB, N_EMBD, dtype=torch.bfloat16)
        self.layers = nn.ModuleList([Attn() for _ in range(N_LAYER)])
        self.head = nn.Linear(N_EMBD, VOCAB, bias=False)

    def forward(self, idx, tgt):
        x = norm(self.wte(idx))
        for layer in self.layers:
            x = x + layer(norm(x))
        logits = self.head(norm(x)).float()
        logits = 15 * torch.tanh(logits / 15)
        return F.cross_entropy(logits.view(-1, VOCAB), tgt.view(-1))


torch.manual_seed(0)
model = Model().cuda()
nn.init.zeros_(model.head.weight)
for layer in model.layers:
    nn.init.zeros_(layer.proj.weight)

# leads to NaNs
model = torch.compile(model, dynamic=False, mode="max-autotune-no-cudagraphs", fullgraph=True)
# works correctly
# model = torch.compile(model, dynamic=False, fullgraph=True)

# Note: @karpathy uses bf16 for AdamW, here we use fp32
opt = torch.optim.AdamW(model.parameters(), lr=0.1, fused=True)

g = torch.Generator(device="cuda").manual_seed(1234)
for step in range(8):
    for _ in range(GRAD_ACCUM):
        x = torch.randint(0, VOCAB, (BATCH, SEQ), device="cuda", generator=g)
        y = torch.randint(0, VOCAB, (BATCH, SEQ), device="cuda", generator=g)
        with torch.amp.autocast("cuda", torch.bfloat16):  # unnecessary
            loss = model(x, y)
        if not math.isfinite(loss.item()):
            print(f"FAIL: NaN at step {step}")
            exit(1)
        (loss / GRAD_ACCUM).backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    model.zero_grad(set_to_none=True)
    print(f"step {step}: {loss.item():.4f}")
print("OK")
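To narrow down where the non-finite values first show up, one option is to scan parameters and gradients after each phase of the step. A minimal sketch, assuming the `model` variable from the script above; the helper name `assert_finite` is mine, not from the gist:

def assert_finite(model, step):
    # Walk every parameter of the (compiled) model; parameters and .grad are
    # ordinary CUDA tensors, so this works the same for eager and compiled runs.
    for name, p in model.named_parameters():
        if not torch.isfinite(p).all():
            raise RuntimeError(f"non-finite parameter {name} at step {step}")
        if p.grad is not None and not torch.isfinite(p.grad).all():
            raise RuntimeError(f"non-finite gradient {name} at step {step}")

Calling it once right after (loss / GRAD_ACCUM).backward() and again after opt.step() would distinguish a NaN produced during the backward pass from one introduced by the fused AdamW update.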