A very specific case that leads to NaNs on newer versions of PyTorch.
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

VOCAB = 65536
SEQ = 2048
BATCH = 32  # fails with 32, works with 16
N_LAYER = 20
N_EMBD = 1280
N_HEAD = 10
GRAD_ACCUM = 8


def norm(x):
    return F.rms_norm(x, (x.size(-1),))


class Attn(nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv = nn.Linear(N_EMBD, 3 * N_EMBD, bias=False)
        self.proj = nn.Linear(N_EMBD, N_EMBD, bias=False)

    def forward(self, x):
        B, T, _ = x.shape
        qkv = self.qkv(x).view(B, T, 3, N_HEAD, N_EMBD // N_HEAD)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        q, k = norm(q), norm(k)  # QK norm
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.proj(y.transpose(1, 2).contiguous().view(B, T, -1))


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(VOCAB, N_EMBD, dtype=torch.bfloat16)
        self.layers = nn.ModuleList([Attn() for _ in range(N_LAYER)])
        self.head = nn.Linear(N_EMBD, VOCAB, bias=False)

    def forward(self, idx, tgt):
        x = norm(self.wte(idx))
        for layer in self.layers:
            x = x + layer(norm(x))
        logits = self.head(norm(x)).float()
        logits = 15 * torch.tanh(logits / 15)  # logit soft-capping at ±15
        return F.cross_entropy(logits.view(-1, VOCAB), tgt.view(-1))


torch.manual_seed(0)
model = Model().cuda()
nn.init.zeros_(model.head.weight)
for layer in model.layers:
    nn.init.zeros_(layer.proj.weight)

# leads to NaNs
model = torch.compile(model, dynamic=False, mode="max-autotune-no-cudagraphs", fullgraph=True)
# works correctly
# model = torch.compile(model, dynamic=False, fullgraph=True)

# Note: @karpathy uses bf16 for AdamW, here we use fp32
opt = torch.optim.AdamW(model.parameters(), lr=0.1, fused=True)

g = torch.Generator(device="cuda").manual_seed(1234)
for step in range(8):
    for _ in range(GRAD_ACCUM):
        x = torch.randint(0, VOCAB, (BATCH, SEQ), device="cuda", generator=g)
        y = torch.randint(0, VOCAB, (BATCH, SEQ), device="cuda", generator=g)
        with torch.amp.autocast("cuda", torch.bfloat16):  # unnecessary
            loss = model(x, y)
        if not math.isfinite(loss.item()):
            print(f"FAIL: NaN at step {step}")
            exit(1)
        (loss / GRAD_ACCUM).backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    model.zero_grad(set_to_none=True)
    print(f"step {step}: {loss.item():.4f}")
print("OK")