@abatilo
Created December 28, 2025 23:25
Tiny Shakespeare DDP Training Script for Spin Tutorial
#!/usr/bin/env python3
"""
Tiny Shakespeare DDP Training Example for Spin
Minimal causal-masked Transformer trained on Tiny Shakespeare using PyTorch DDP.
Uses character-level tokenization (~65 unique characters) for simplicity.
Designed to run on multiple nodes via Spin's SyncSet orchestration.
Usage with spinctl:

    # Push this script to your SyncSet pods
    spinctl push <syncset-name>

    # Run distributed training on all pods
    spinctl run <syncset-name> --all -- sh -c '/root/.local/bin/uv run \
        --with requests --with torch --with numpy \
        torchrun --nnodes=$SPIN_NNODES --node_rank=$SPIN_NODE_RANK \
        --nproc_per_node=8 --rdzv_endpoint=$SPIN_MASTER_ADDR:29500 \
        train.py'

Note: Tiny Shakespeare (~1MB) downloads on first run. Add 'data/' to
.spinignore to avoid pushing the cached data on subsequent code iterations.
"""
import os
import requests
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
# ---- Data: Tiny Shakespeare -> character-level tokenization ----
def get_data():
    """Download and tokenize Tiny Shakespeare dataset."""
    data_dir = "data"
    input_path = os.path.join(data_dir, "input.txt")
    # Download if needed
    if not os.path.exists(input_path):
        os.makedirs(data_dir, exist_ok=True)
        url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        print(f"Downloading Tiny Shakespeare from {url}...")
        text = requests.get(url).text
        with open(input_path, "w") as f:
            f.write(text)
    else:
        with open(input_path, "r") as f:
            text = f.read()
    # Character-level tokenization
    chars = sorted(list(set(text)))
    vocab_size = len(chars)
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}
    # Encode entire text
    ids = [stoi[c] for c in text]
    return torch.tensor(ids, dtype=torch.long), vocab_size, itos

class LMSeqDataset(Dataset):
    """Sliding-window dataset: each item is (seq_len inputs, seq_len next-token targets)."""

    def __init__(self, data, seq_len):
        self.data, self.seq_len = data, seq_len

    def __len__(self):
        return len(self.data) - self.seq_len - 1

    def __getitem__(self, i):
        c = self.data[i : i + self.seq_len + 1]
        return c[:-1], c[1:]
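
# Illustration with made-up values: with seq_len=4 and data = [5, 1, 7, 2, 9, 3],
# item 0 is x = [5, 1, 7, 2] and y = [1, 7, 2, 9]; the target at each position is
# simply the next character id in the text.
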
# ---- Minimal Transformer LM (decoder-only via causal mask) ----
class TransformerLM(nn.Module):
    def __init__(
        self, vocab_size, d_model=512, nhead=4, nlayer=4, dff=1024, max_len=256
    ):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead, dff, batch_first=True)
        self.tr = nn.TransformerEncoder(layer, nlayer)
        self.head = nn.Linear(d_model, vocab_size)
        mask = torch.triu(torch.full((max_len, max_len), float("-inf")), 1)
        self.register_buffer("mask", mask)
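
        # Illustration: for T = 3 the registered causal mask is
        #   [[0., -inf, -inf],
        #    [0.,   0., -inf],
        #    [0.,   0.,   0.]]
        # so each position can attend only to itself and earlier positions.
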
    def forward(self, x):
        B, T = x.shape
        p = torch.arange(T, device=x.device).unsqueeze(0)
        h = self.tok(x) + self.pos(p)
        h = self.tr(h, mask=self.mask[:T, :T])
        return self.head(h)

@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, max_len=256):
    """Generate tokens autoregressively."""
    model.eval()
    for _ in range(max_new_tokens):
        idx_cond = idx if idx.size(1) <= max_len else idx[:, -max_len:]
        logits = model(idx_cond)
        logits = logits[:, -1, :] / temperature
        probs = torch.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, idx_next], dim=1)
    return idx

def sample_and_print(model, device, itos, seq_len, label="Sample"):
    """Generate and print a 256-character sample."""
    start = torch.zeros((1, 1), dtype=torch.long, device=device)
    generated = generate(
        model, start, max_new_tokens=256, temperature=0.8, max_len=seq_len
    )
    text = "".join([itos[i] for i in generated[0].tolist()])
    print(f"\n--- {label} ---")
    print(text)
    print("--- End ---\n")

def main():
    # ---- DDP init (assume CUDA+NCCL always) ----
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    is_main = dist.get_rank() == 0

    # ---- Config (adjusted for character-level Shakespeare) ----
    seq_len = 256
    batch_size = 512  # can use larger batch with small vocab
    epochs = 4
    steps_per_epoch = 10_000_000  # effectively uncapped; the DataLoader is exhausted first
    lr = 3e-4

    # ---- Data ----
    if is_main:
        print("Loading Tiny Shakespeare + building vocab...")
    data, vocab_size, itos = get_data()
    if is_main:
        print(f"Dataset: {len(data):,} characters, vocab_size: {vocab_size}")
    ds = LMSeqDataset(data, seq_len)
    sampler = DistributedSampler(ds)
    dl = DataLoader(
        ds, batch_size=batch_size, sampler=sampler, num_workers=0, pin_memory=True
    )

    # ---- Model ----
    model = TransformerLM(vocab_size, max_len=seq_len).to(device)
    model = DDP(model, device_ids=[local_rank])
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    # ---- Train ----
    for e in range(epochs):
        sampler.set_epoch(e)
        model.train()
        if is_main:
            print(f"Epoch {e + 1}/{epochs}")
        for step, (x, y) in enumerate(dl):
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            logits = model(x)  # [B, T, V]
            B, T, V = logits.shape
            loss = loss_fn(logits.view(B * T, V), y.view(B * T))
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            if is_main and step % 10 == 0:
                print(f" step {step:03d} loss={loss.item():.4f}")
            if step >= steps_per_epoch:
                break
        # Sample after each epoch
        if is_main:
            sample_and_print(
                model.module, device, itos, seq_len, f"Epoch {e + 1} sample"
            )

    if is_main:
        os.makedirs("ckpt", exist_ok=True)
        torch.save(model.module.state_dict(), "ckpt/shakespeare_transformer.pt")
        print("Saved ckpt/shakespeare_transformer.pt")
        # Final generation sample
        sample_and_print(model.module, device, itos, seq_len, "Final sample")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
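
# A minimal follow-up sketch (assumes the checkpoint saved above exists and that
# the vocab is rebuilt from the same input.txt so itos matches training): load
# the weights on CPU and sample from the trained model.
#
#   data, vocab_size, itos = get_data()
#   m = TransformerLM(vocab_size, max_len=256)
#   m.load_state_dict(torch.load("ckpt/shakespeare_transformer.pt", map_location="cpu"))
#   out = generate(m, torch.zeros((1, 1), dtype=torch.long), max_new_tokens=256)
#   print("".join(itos[i] for i in out[0].tolist()))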