Distributed training strategy submission
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.utils as nn_utils
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, Any, Type, List
from abc import ABC, abstractmethod
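
# Minimal process-group helpers. init_process_group("gloo") picks up its
# rendezvous info from environment variables (MASTER_ADDR, MASTER_PORT,
# RANK, WORLD_SIZE), which torchrun sets automatically, e.g.:
#   torchrun --nproc_per_node=4 this_script.py
# gloo is a CPU backend, so no GPUs are needed for this demo.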
def setup():
    if not dist.is_initialized():
        dist.init_process_group("gloo")
    return dist.get_rank(), dist.get_world_size()


def cleanup():
    if dist.is_initialized():
        dist.destroy_process_group()


def all_reduce(t):
    dist.all_reduce(t, op=dist.ReduceOp.SUM)


def broadcast(t, src=0):
    dist.broadcast(t, src)

@dataclass
class OptimSpec:
    cls: Type[torch.optim.Optimizer]
    kwargs: Dict[str, Any]

    def build(self, model):
        return self.cls(model.parameters(), **self.kwargs)
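
# Example: OptimSpec(torch.optim.AdamW, {"lr": 3e-4}) defers optimizer
# construction until each node has built its local model replica.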

class Strategy(ABC):
    def __init__(self, max_steps=100):
        self.max_steps = max_steps

    def _init_node(self, model, rank, world):
        self.model = model
        self.rank = rank
        self.world = world
        self.step_idx = 0

    def zero_grad(self):
        self.optim.zero_grad(set_to_none=True)

    @abstractmethod
    def step(self):
        pass
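
# A CommunicationModule decides *when* and *what* to synchronize across
# nodes; a Strategy owns the inner optimizer and calls its modules after
# every local step.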

class CommunicationModule(ABC):
    @abstractmethod
    def _init_node(self, model, rank, world):
        pass

    @abstractmethod
    def communicate(self, model, rank, world, step):
        pass
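
# Gradient-spike-aware DiLoCo: nodes train independently and consider a
# sync every H steps, but only go through with it when the cross-node mean
# gradient norm exceeds alpha times its running EMA (i.e. on a spike).
# On a sync, parameters are averaged; rank 0 then applies a DiLoCo-style
# outer update, feeding the pseudo-gradient (master - averaged params) to
# Nesterov-momentum SGD on a master copy, which is broadcast to all ranks.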

class GSADDiLoCo(CommunicationModule):
    def __init__(self, H=10, alpha=1.3, outer_lr=0.6):
        self.H = H
        self.alpha = alpha
        self.outer_lr = outer_lr
        self.master = None
        self.outer_optim = None
        self.ema = None

    def _init_node(self, model, rank, world):
        if rank == 0:
            self.master = deepcopy(model).cpu()
            self.outer_optim = torch.optim.SGD(
                self.master.parameters(),
                lr=self.outer_lr,
                momentum=0.9,
                nesterov=True,
            )
        # Every rank tracks the EMA so the spike test is evaluated
        # identically everywhere (the reduced norm is the same on all ranks).
        self.ema = torch.tensor(0.0)
    def communicate(self, model, rank, world, step):
        if world == 1:
            return
        if step == 0 or step % self.H != 0:
            return
        # Cross-node mean of the local gradient norms.
        total_norm = torch.zeros(1)
        for p in model.parameters():
            if p.grad is not None:
                total_norm += p.grad.data.norm() ** 2
        total_norm = total_norm.sqrt()
        all_reduce(total_norm)
        total_norm /= world
        if self.ema.item() == 0.0:
            self.ema = total_norm.detach()
        else:
            self.ema = 0.9 * self.ema + 0.1 * total_norm.detach()
        if total_norm < self.alpha * self.ema:
            return
        with torch.no_grad():
            for p in model.parameters():
                all_reduce(p.data)
                p.data /= world
            if rank == 0:
                self.outer_optim.zero_grad()
                for mp, p in zip(self.master.parameters(), model.parameters()):
                    mp.grad = mp.data - p.data.cpu()
                self.outer_optim.step()
                # Bug fix: copy the outer-optimizer result back into the live
                # model before broadcasting; otherwise workers would receive
                # the plain average and the outer step would be lost.
                for mp, p in zip(self.master.parameters(), model.parameters()):
                    p.data.copy_(mp.data)
            for p in model.parameters():
                broadcast(p.data, src=0)
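
# StableStrategy: gradient clipping plus a local optimizer step, followed
# by whatever synchronization its CommunicationModules decide to perform.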

class StableStrategy(Strategy):
    def __init__(self, optim_spec, comms: List[CommunicationModule], max_norm=1.0, max_steps=100):
        super().__init__(max_steps=max_steps)
        self.optim_spec = optim_spec
        self.comms = comms
        self.max_norm = max_norm

    def _init_node(self, model, rank, world):
        super()._init_node(model, rank, world)
        self.optim = self.optim_spec.build(model)
        for c in self.comms:
            c._init_node(model, rank, world)

    def step(self):
        if self.max_norm:
            nn_utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        self.optim.step()
        for c in self.comms:
            c.communicate(self.model, self.rank, self.world, self.step_idx)
        self.step_idx += 1
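
# A stand-in model and synthetic batches keep the demo self-contained;
# swap in a real model and dataloader for actual use.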

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        return self.net(x)
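
# Per-step order matters: zero_grad -> backward -> strategy.step(), so that
# GSADDiLoCo still sees this step's gradients when it measures their norm.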

def main():
    rank, world = setup()
    # Same seed on every rank so all replicas start from identical weights.
    torch.manual_seed(0)
    model = ToyModel()
    # Re-seed per rank afterwards so each worker draws *different* synthetic
    # batches; otherwise averaging across identical workers is a no-op.
    torch.manual_seed(1234 + rank)
    strategy = StableStrategy(
        optim_spec=OptimSpec(
            torch.optim.AdamW,
            {"lr": 3e-4, "weight_decay": 0.01},
        ),
        comms=[GSADDiLoCo(H=10, alpha=1.3, outer_lr=0.6)],
        max_norm=1.0,
        max_steps=100,
    )
    strategy._init_node(model, rank, world)
    loss_fn = nn.CrossEntropyLoss()
    for step in range(strategy.max_steps):
        x = torch.randn(32, 128)
        y = torch.randint(0, 10, (32,))
        loss = loss_fn(model(x), y)
        strategy.zero_grad()
        loss.backward()
        strategy.step()
        if rank == 0 and step % 10 == 0:
            print(f"[step {step:03d}] loss={loss.item():.4f}")
    cleanup()

if __name__ == "__main__":
    main()