Tiny Shakespeare DDP Training Script for Spin Tutorial
#!/usr/bin/env python3
"""
Tiny Shakespeare DDP Training Example for Spin

Minimal causal-masked Transformer trained on Tiny Shakespeare using PyTorch DDP.
Uses character-level tokenization (~65 unique characters) for simplicity.
Designed to run on multiple nodes via Spin's SyncSet orchestration.

Usage with spinctl:

    # Push this script to your SyncSet pods
    spinctl push <syncset-name>

    # Run distributed training on all pods
    spinctl run <syncset-name> --all -- sh -c '/root/.local/bin/uv run \
        --with requests --with torch --with numpy \
        torchrun --nnodes=$SPIN_NNODES --node_rank=$SPIN_NODE_RANK \
        --nproc_per_node=8 --rdzv_endpoint=$SPIN_MASTER_ADDR:29500 \
        train.py'

Note: Tiny Shakespeare (~1MB) downloads on first run. Add 'data/' to
.spinignore to avoid pushing the cached data on subsequent code iterations.
"""

import os

import requests
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler


# ---- Data: Tiny Shakespeare -> character-level tokenization ----
def get_data():
    """Download and tokenize Tiny Shakespeare dataset."""
    data_dir = "data"
    input_path = os.path.join(data_dir, "input.txt")

    # Download if needed. Note: every rank runs this; the file is ~1MB, so that
    # is usually harmless, but strictly only one process per node needs to fetch it.
    if not os.path.exists(input_path):
        os.makedirs(data_dir, exist_ok=True)
        url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        print(f"Downloading Tiny Shakespeare from {url}...")
        text = requests.get(url).text
        with open(input_path, "w") as f:
            f.write(text)
    else:
        with open(input_path, "r") as f:
            text = f.read()

    # Character-level tokenization
    chars = sorted(list(set(text)))
    vocab_size = len(chars)
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}

    # Encode entire text
    ids = [stoi[c] for c in text]
    return torch.tensor(ids, dtype=torch.long), vocab_size, itos
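
# Illustrative sanity check (not executed here): decoding inverts encoding, e.g.
#   data, vocab_size, itos = get_data()
#   "".join(itos[int(i)] for i in data[:20]) == open("data/input.txt").read()[:20]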


class LMSeqDataset(Dataset):
    """Sliding windows over the token stream: input is a seq_len chunk, target is the same chunk shifted by one."""

    def __init__(self, data, seq_len):
        self.data, self.seq_len = data, seq_len

    def __len__(self):
        return len(self.data) - self.seq_len - 1

    def __getitem__(self, i):
        c = self.data[i : i + self.seq_len + 1]
        return c[:-1], c[1:]


# ---- Minimal Transformer LM (decoder-only via causal mask) ----
class TransformerLM(nn.Module):
    def __init__(
        self, vocab_size, d_model=512, nhead=4, nlayer=4, dff=1024, max_len=256
    ):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead, dff, batch_first=True)
        self.tr = nn.TransformerEncoder(layer, nlayer)
        self.head = nn.Linear(d_model, vocab_size)
        mask = torch.triu(torch.full((max_len, max_len), float("-inf")), 1)
        self.register_buffer("mask", mask)
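        # Illustration: for max_len=4 the additive mask built above is
        #   [[0., -inf, -inf, -inf],
        #    [0.,   0., -inf, -inf],
        #    [0.,   0.,   0., -inf],
        #    [0.,   0.,   0.,   0.]]
        # so position t can only attend to positions <= t (causal attention).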

    def forward(self, x):
        B, T = x.shape
        p = torch.arange(T, device=x.device).unsqueeze(0)
        h = self.tok(x) + self.pos(p)
        h = self.tr(h, mask=self.mask[:T, :T])
        return self.head(h)


@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, max_len=256):
    """Generate tokens autoregressively."""
    model.eval()
    for _ in range(max_new_tokens):
        idx_cond = idx if idx.size(1) <= max_len else idx[:, -max_len:]
        logits = model(idx_cond)
        logits = logits[:, -1, :] / temperature
        probs = torch.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, idx_next], dim=1)
    return idx


def sample_and_print(model, device, itos, seq_len, label="Sample"):
    """Generate and print a 256-character sample."""
    start = torch.zeros((1, 1), dtype=torch.long, device=device)
    generated = generate(
        model, start, max_new_tokens=256, temperature=0.8, max_len=seq_len
    )
    text = "".join([itos[i] for i in generated[0].tolist()])
    print(f"\n--- {label} ---")
    print(text)
    print("--- End ---\n")


def main():
    # ---- DDP init (assume CUDA+NCCL always) ----
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    is_main = dist.get_rank() == 0

    # ---- Config (adjusted for character-level Shakespeare) ----
    seq_len = 256
    batch_size = 512  # per-rank batch; can use a larger batch with the small vocab
    epochs = 4
    steps_per_epoch = 10_000_000  # effectively no cap: each epoch runs the full DataLoader
    lr = 3e-4

    # ---- Data ----
    if is_main:
        print("Loading Tiny Shakespeare + building vocab...")
    data, vocab_size, itos = get_data()
    if is_main:
        print(f"Dataset: {len(data):,} characters, vocab_size: {vocab_size}")

    ds = LMSeqDataset(data, seq_len)
    sampler = DistributedSampler(ds)
    dl = DataLoader(
        ds, batch_size=batch_size, sampler=sampler, num_workers=0, pin_memory=True
    )

    # ---- Model ----
    model = TransformerLM(vocab_size, max_len=seq_len).to(device)
    model = DDP(model, device_ids=[local_rank])
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
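
    # Note: DDP all-reduces (averages) gradients across ranks during backward(),
    # so every rank applies the same optimizer step and the replicas stay in sync.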

    # ---- Train ----
    for e in range(epochs):
        sampler.set_epoch(e)  # so each epoch sees a different shuffle
        model.train()
        if is_main:
            print(f"Epoch {e + 1}/{epochs}")
        for step, (x, y) in enumerate(dl):
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            logits = model(x)  # [B, T, V]
            B, T, V = logits.shape
            loss = loss_fn(logits.view(B * T, V), y.view(B * T))
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            if is_main and step % 10 == 0:
                print(f" step {step:03d} loss={loss.item():.4f}")
            if step >= steps_per_epoch:
                break

        # Sample after each epoch
        if is_main:
            sample_and_print(
                model.module, device, itos, seq_len, f"Epoch {e + 1} sample"
            )

    if is_main:
        os.makedirs("ckpt", exist_ok=True)
        torch.save(model.module.state_dict(), "ckpt/shakespeare_transformer.pt")
        print("Saved ckpt/shakespeare_transformer.pt")

        # Final generation sample
        sample_and_print(model.module, device, itos, seq_len, "Final sample")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
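
To poke at the trained model outside of the distributed job, a minimal sampling sketch could look like the following. This is an illustration rather than part of the gist: it assumes the script above is saved as train.py and that ckpt/shakespeare_transformer.pt exists from a completed run.

# sample_ckpt.py -- hypothetical companion script for generating from the checkpoint.
import torch

from train import TransformerLM, generate, get_data

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the character vocabulary exactly as training did.
data, vocab_size, itos = get_data()

# Restore the rank-0 checkpoint; hyperparameters must match the training defaults.
model = TransformerLM(vocab_size, max_len=256).to(device)
state = torch.load("ckpt/shakespeare_transformer.pt", map_location=device)
model.load_state_dict(state)

# Token 0 is the first character of the sorted vocab (a newline for this corpus).
start = torch.zeros((1, 1), dtype=torch.long, device=device)
out = generate(model, start, max_new_tokens=500, temperature=0.8, max_len=256)
print("".join(itos[i] for i in out[0].tolist()))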