Distributed training strategy submission
import os
import math
import torch
import torch.distributed as dist
import torch.nn.utils as nn_utils
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, Any, Type
from abc import ABC, abstractmethod

def init_dist():
    # Idempotent init; RANK / WORLD_SIZE / MASTER_ADDR come from the launcher.
    if dist.is_initialized():
        return
    dist.init_process_group(backend="gloo", init_method="env://")


def get_rank():
    return int(os.environ["RANK"])


def get_world():
    return int(os.environ["WORLD_SIZE"])


def broadcast(t, src=0):
    dist.broadcast(t, src)


def all_reduce(t):
    dist.all_reduce(t)

@dataclass
class OptimSpec:
    """Deferred optimizer construction: an optimizer class plus kwargs, built per model."""
    cls: Type[torch.optim.Optimizer]
    kwargs: Dict[str, Any]

    def build(self, model):
        return self.cls(model.parameters(), **self.kwargs)

class Strategy(ABC):
    def __init__(self):
        self.max_steps = 100

    def _init_node(self, model):
        self.model = model
        self.rank = get_rank()
        self.world = get_world()
        self.local_step = 0

    @abstractmethod
    def step(self):
        self.local_step += 1

    def zero_grad(self):
        # Subclasses are expected to create self.optim in _init_node.
        self.optim.zero_grad(set_to_none=True)

class CommunicationModule(ABC):
    @abstractmethod
    def _init_node(self, model):
        pass

    @abstractmethod
    def communicate(self, model, step):
        pass

class CommunicateOptimizeStrategy(Strategy):
    def __init__(self, modules, optim_spec, max_norm=1.0):
        super().__init__()
        self.modules = modules
        self.optim_spec = optim_spec
        self.max_norm = max_norm
        for m in modules:
            m.strategy = self

    def _init_node(self, model):
        super()._init_node(model)
        self.optim = self.optim_spec.build(model)
        for m in self.modules:
            m._init_node(model)

    def step(self):
        # Local update first, then let each module decide whether to sync.
        nn_utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        self.optim.step()
        for m in self.modules:
            m.communicate(self.model, self.local_step)
        self.local_step += 1

class GradientEnergyRouter:
    """Top-k magnitude sparsifier whose keep-ratio schedule tightens over training."""

    def ratio(self, step):
        if step < 20:
            return 0.03
        if step < 50:
            return 0.02
        return 0.012

    def select(self, g, step):
        # Keep only the k largest-magnitude entries of g; zero the rest.
        r = self.ratio(step)
        flat = g.view(-1)
        k = max(1, int(flat.numel() * r))
        idx = torch.topk(flat.abs(), k).indices
        mask = torch.zeros_like(flat)
        mask[idx] = 1
        return (flat * mask).view_as(g)
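
# Illustrative standalone sketch (not part of the submission): a quick
# sanity check of the router's top-k selection on a toy tensor. At step 0
# the keep ratio is 0.03, so on a 1000-element gradient roughly 30 entries
# survive and the rest are zeroed. Call manually to verify the sparsity.
def _demo_router():
    router = GradientEnergyRouter()
    g = torch.randn(1000)
    sparse = router.select(g, step=0)
    kept = int((sparse != 0).sum())
    print(f"kept {kept}/{g.numel()} entries")  # expect ~30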

class GSERDiLoCo(CommunicationModule):
    def __init__(self, base_H=12):
        self.base_H = base_H
        self.router = GradientEnergyRouter()
        self.energy_ema = None
        self.error = {}  # per-parameter residuals for error feedback

    def _init_node(self, model):
        self.rank = get_rank()
        self.world = get_world()
        if self.rank == 0:
            # Rank 0 keeps a CPU master copy, updated by an outer optimizer
            # on DiLoCo-style pseudo-gradients.
            self.master = deepcopy(model).cpu()
            self.outer_optim = torch.optim.SGD(
                self.master.parameters(),
                lr=0.35,
                momentum=0.9,
            )
        for n, p in model.named_parameters():
            self.error[n] = torch.zeros_like(p.data)

    def _energy(self, model):
        # Global gradient L2 norm, all-reduced so every rank sees the same value.
        e = torch.tensor(0.0, device=next(model.parameters()).device)
        for p in model.parameters():
            if p.grad is not None:
                e += (p.grad ** 2).sum()
        all_reduce(e)
        return math.sqrt(e.item())

    def communicate(self, model, step):
        energy = self._energy(model)
        if self.energy_ema is None:
            self.energy_ema = energy
        else:
            self.energy_ema = 0.9 * self.energy_ema + 0.1 * energy
        # Adapt the sync interval H: sync more often when gradient energy
        # spikes above its EMA, less often once it settles. H is derived
        # from all-reduced quantities, so every rank computes the same value
        # and the collectives below stay aligned.
        ratio = self.energy_ema / (energy + 1e-8)
        ratio = max(0.5, min(2.0, ratio))
        H = int(self.base_H * ratio)
        if step < 30:
            H = max(4, min(10, H))
        else:
            H = max(8, min(25, H))
        if step % H != 0 or self.world == 1:
            return
        # Sparse gradient exchange with error feedback: the entries dropped
        # by the router accumulate in self.error and re-enter later.
        for name, p in model.named_parameters():
            g = p.grad + self.error[name]
            sparse = self.router.select(g, step)
            self.error[name] = g - sparse
            all_reduce(sparse)
            p.grad = sparse / self.world
        if self.rank == 0:
            self.outer_optim.zero_grad()
            for mp, lp in zip(self.master.parameters(), model.parameters()):
                # DiLoCo pseudo-gradient: displacement of the local weights
                # from the master copy.
                mp.grad = mp.data - lp.data.cpu()
            self.outer_optim.step()
            # Sync the updated master weights into rank 0's local model so
            # the broadcast below ships the post-outer-step weights.
            for mp, lp in zip(self.master.parameters(), model.parameters()):
                lp.data.copy_(mp.data.to(lp.device))
        # Every rank must join the broadcast of the new weights.
        for p in model.parameters():
            broadcast(p.data, 0)

STRATEGY = CommunicateOptimizeStrategy(
    modules=[GSERDiLoCo()],
    optim_spec=OptimSpec(torch.optim.AdamW, {"lr": 1e-3}),
    max_norm=1.0,
)
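
# Minimal driver sketch, assumed rather than shown in the submission: how
# STRATEGY might be exercised under torchrun (e.g.
# `torchrun --nproc_per_node=2 gist.py`), which sets the RANK / WORLD_SIZE /
# MASTER_ADDR environment variables that init_dist() relies on. The linear
# model and random batches are hypothetical placeholders for whatever
# harness actually drives the strategy.
if __name__ == "__main__":
    init_dist()
    model = torch.nn.Linear(32, 1)
    STRATEGY._init_node(model)
    for _ in range(STRATEGY.max_steps):
        x = torch.randn(8, 32)
        loss = model(x).pow(2).mean()
        loss.backward()
        STRATEGY.step()  # inner AdamW step + GSER-DiLoCo communication
        STRATEGY.zero_grad()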