Distributed training strategy submission
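A short orientation docstring for the file below, summarizing what the code actually does (a sketch in my words, not the author's):

"""GSAD-DiLoCo-style distributed training sketch.

Workers take local AdamW steps; every H steps a gradient-norm gate decides
whether to run a full synchronization round: average parameters across ranks,
apply a Nesterov-SGD "outer" step to a master copy held on rank 0, and
broadcast the result back to all workers.
"""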
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.utils as nn_utils
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, Any, Type, List
from abc import ABC, abstractmethod
def setup():
    """Initialize the default process group (gloo) and return (rank, world_size)."""
    if not dist.is_initialized():
        dist.init_process_group("gloo")
    return dist.get_rank(), dist.get_world_size()


def cleanup():
    if dist.is_initialized():
        dist.destroy_process_group()


def all_reduce(t):
    # In-place sum across all ranks; callers divide by world size to average.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)


def broadcast(t, src=0):
    dist.broadcast(t, src)
@dataclass
class OptimSpec:
    """Recipe for building a fresh optimizer: optimizer class plus constructor kwargs."""

    cls: Type[torch.optim.Optimizer]
    kwargs: Dict[str, Any]

    def build(self, model):
        return self.cls(model.parameters(), **self.kwargs)
class Strategy(ABC):
    def __init__(self, max_steps=100):
        self.max_steps = max_steps

    def _init_node(self, model, rank, world):
        self.model = model
        self.rank = rank
        self.world = world
        self.step_idx = 0

    def zero_grad(self):
        # self.optim is created by subclasses in their _init_node override.
        self.optim.zero_grad(set_to_none=True)

    @abstractmethod
    def step(self):
        pass
class CommunicationModule(ABC):
    @abstractmethod
    def _init_node(self, model, rank, world):
        pass

    @abstractmethod
    def communicate(self, model, rank, world, step):
        pass
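# Division of labor, as I read the two ABCs above: a Strategy owns the inner
# optimizer and per-step bookkeeping, while CommunicationModules are
# composable hooks invoked after every inner step to synchronize state
# across ranks.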
class GSADDiLoCo(CommunicationModule):
    """DiLoCo-style periodic synchronization, gated on gradient-norm spikes.

    Every H steps the cross-rank average gradient norm is compared against an
    EMA of itself; the synchronization round (parameter averaging, an outer
    Nesterov-SGD step on a rank-0 master copy, then a broadcast) only runs
    when the norm exceeds alpha times the EMA.
    """

    def __init__(self, H=10, alpha=1.3, outer_lr=0.6):
        self.H = H
        self.alpha = alpha
        self.outer_lr = outer_lr
        self.master = None
        self.outer_optim = None
        self.ema = None

    def _init_node(self, model, rank, world):
        # The EMA gate must exist on every rank, not just rank 0, since all
        # ranks evaluate the gate; only the master copy and the outer
        # optimizer live on rank 0.
        self.ema = torch.tensor(0.0)
        if rank == 0:
            self.master = deepcopy(model).cpu()
            self.outer_optim = torch.optim.SGD(
                self.master.parameters(),
                lr=self.outer_lr,
                momentum=0.9,
                nesterov=True,
            )

    def communicate(self, model, rank, world, step):
        if world == 1:
            return
        if step == 0 or step % self.H != 0:
            return
        # Average of the per-rank gradient norms; identical on every rank
        # after the all-reduce, so all ranks take the same branch below.
        total_norm = torch.zeros(1)
        for p in model.parameters():
            if p.grad is not None:
                total_norm += p.grad.data.norm() ** 2
        total_norm = total_norm.sqrt()
        all_reduce(total_norm)
        total_norm /= world
        if self.ema.item() == 0.0:
            self.ema = total_norm.detach()
        else:
            self.ema = 0.9 * self.ema + 0.1 * total_norm.detach()
        # Gate: skip the synchronization round unless the norm spikes above
        # alpha times its running average.
        if total_norm < self.alpha * self.ema:
            return
        with torch.no_grad():
            # Average parameters across ranks.
            for p in model.parameters():
                all_reduce(p.data)
                p.data /= world
            if rank == 0:
                # Outer step: treat (master - average) as a pseudo-gradient,
                # step the master with Nesterov SGD, then adopt the updated
                # master locally so the broadcast below distributes it.
                self.outer_optim.zero_grad()
                for mp, p in zip(self.master.parameters(), model.parameters()):
                    mp.grad = mp.data - p.data.cpu()
                self.outer_optim.step()
                for mp, p in zip(self.master.parameters(), model.parameters()):
                    p.data.copy_(mp.data)
            for p in model.parameters():
                broadcast(p.data, src=0)
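# Gate arithmetic, for intuition: with alpha = 1.3 and an EMA of 2.0, a sync
# round at a multiple of H only fires if the cross-rank average grad norm is
# at least 2.6; quieter intervals stay fully local and cost no communication.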
class StableStrategy(Strategy):
    def __init__(self, optim_spec, comms: List[CommunicationModule], max_norm=1.0, max_steps=100):
        super().__init__(max_steps=max_steps)
        self.optim_spec = optim_spec
        self.comms = comms
        self.max_norm = max_norm

    def _init_node(self, model, rank, world):
        super()._init_node(model, rank, world)
        self.optim = self.optim_spec.build(model)
        for c in self.comms:
            c._init_node(model, rank, world)

    def step(self):
        # Clip, take the inner optimizer step, then give each communication
        # module a chance to synchronize across ranks.
        if self.max_norm:
            nn_utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        self.optim.step()
        for c in self.comms:
            c.communicate(self.model, self.rank, self.world, self.step_idx)
        self.step_idx += 1
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        return self.net(x)
def main():
    rank, world = setup()
    # Identical seed before model construction so every rank starts from the
    # same weights; then diverge the data stream per rank so the averaged
    # updates actually differ across workers.
    torch.manual_seed(0)
    model = ToyModel()
    torch.manual_seed(1234 + rank)
    strategy = StableStrategy(
        optim_spec=OptimSpec(
            torch.optim.AdamW,
            {"lr": 3e-4, "weight_decay": 0.01},
        ),
        comms=[GSADDiLoCo(H=10, alpha=1.3, outer_lr=0.6)],
        max_norm=1.0,
        max_steps=100,
    )
    strategy._init_node(model, rank, world)
    loss_fn = nn.CrossEntropyLoss()
    for step in range(strategy.max_steps):
        x = torch.randn(32, 128)
        y = torch.randint(0, 10, (32,))
        loss = loss_fn(model(x), y)
        strategy.zero_grad()
        loss.backward()
        strategy.step()
        if rank == 0 and step % 10 == 0:
            print(f"[step {step:03d}] loss={loss.item():.4f}")
    cleanup()


if __name__ == "__main__":
    main()
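To exercise this across processes on one machine, the usual torch.distributed environment variables have to be set for each rank. A minimal launcher sketch, assuming the file above is saved as gsad_diloco.py (both file names here are hypothetical):

# launch_local.py -- minimal local launcher sketch (file names hypothetical).
import os
import torch.multiprocessing as mp

from gsad_diloco import main  # assumes the gist above is saved as gsad_diloco.py


def _worker(rank, world):
    # These env vars are read by dist.init_process_group("gloo") via the
    # default env:// init method.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world)
    main()


if __name__ == "__main__":
    world = 2
    mp.spawn(_worker, args=(world,), nprocs=world)

Equivalently, torchrun --nproc_per_node=2 gsad_diloco.py sets the same variables and spawns the workers for you.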