@strnan
Created December 28, 2025 08:51
Distributed training strategy submission
import math
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type, Union

import torch
import torch.distributed as dist
import torch.nn.utils as nn_utils
from torch.optim.lr_scheduler import LambdaLR

from exogym.aux.utils import LogModule


def broadcast(tensor, src=0):
    return dist.broadcast(tensor, src=src)


def all_reduce(tensor, op=dist.ReduceOp.SUM):
    return dist.all_reduce(tensor, op=op)

@dataclass
class OptimSpec:
    """An optimizer class plus the kwargs needed to construct it."""

    cls: Type[torch.optim.Optimizer]
    kwargs: Dict[str, Any]

    def build(self, model):
        return self.cls(model.parameters(), **self.kwargs)


def ensure_optim_spec(
    optim: Union[str, "OptimSpec", None],
    default: Optional["OptimSpec"] = None,
    **kwargs,
) -> "OptimSpec":
    """Normalize a string, OptimSpec, or None into an OptimSpec."""
    if optim is None:
        return default or OptimSpec(torch.optim.AdamW, kwargs)
    if isinstance(optim, OptimSpec):
        return optim
    if isinstance(optim, str):
        name = optim.lower()
        mapping = {
            "adam": torch.optim.Adam,
            "adamw": torch.optim.AdamW,
            "sgd": torch.optim.SGD,
        }
        return OptimSpec(mapping[name], kwargs)
    raise TypeError(f"Unsupported optimizer specification: {optim!r}")
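
# Illustrative only (not part of the submission): ensure_optim_spec maps a string
# name onto the matching torch optimizer class and carries any extra kwargs along,
# so callers can pass either a plain string or a fully specified OptimSpec.
_sgd_spec = ensure_optim_spec("sgd", lr=0.01, momentum=0.9)
assert _sgd_spec.cls is torch.optim.SGD
assert _sgd_spec.kwargs == {"lr": 0.01, "momentum": 0.9}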

class Strategy(ABC, LogModule):
    """Base class: owns the inner optimizer, LR schedule, and step counter."""

    def __init__(self, lr_scheduler=None, lr_scheduler_kwargs=None, **kwargs):
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or {}
        self.scheduler = None
        self.lr_callbacks = []
        self.max_steps = 1
        # Any extra keyword arguments become attributes (e.g. H, max_norm).
        for k, v in kwargs.items():
            setattr(self, k, v)

    def _init_node(self, model, rank, num_nodes):
        self.model = model
        self.rank = rank
        self.num_nodes = num_nodes
        self.local_step = 0
        self.nbytes = 0  # running count of bytes communicated
        if hasattr(self, "optim_spec"):
            self.optim = self.optim_spec.build(model)

    def zero_grad(self):
        self.optim.zero_grad()

    def _setup_scheduler(self):
        def lr_lambda(step):
            # Linear warmup followed by (optional) cosine annealing to zero.
            warmup = self.lr_scheduler_kwargs.get("warmup_steps", 1)
            max_steps = min(
                self.lr_scheduler_kwargs.get("max_steps", self.max_steps),
                self.max_steps,
            )
            cosine = self.lr_scheduler_kwargs.get("cosine_anneal", True)
            if step < warmup:
                return step / max(1, warmup)
            if cosine:
                progress = (step - warmup) / max(1, max_steps - warmup)
                return 0.5 * (1 + math.cos(math.pi * progress))
            return 1.0

        if self.lr_scheduler == "lambda_cosine":
            self.scheduler = LambdaLR(self.optim, lr_lambda)

    @abstractmethod
    def step(self):
        # Subclasses call super().step() after their own work.
        if self.scheduler:
            self.scheduler.step()
        self.local_step += 1

class CommunicationModule(ABC):
    @abstractmethod
    def _init_node(self, model, rank, num_nodes):
        pass

    @abstractmethod
    def communicate(self, model, rank, num_nodes, local_step):
        pass

class CommunicateOptimizeStrategy(Strategy):
    """Runs the inner optimizer every step and delegates parameter syncing
    to a list of communication modules."""

    def __init__(
        self,
        communication_modules: List[CommunicationModule],
        optim_spec=None,
        max_norm=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.optim_spec = ensure_optim_spec(
            optim_spec, OptimSpec(torch.optim.AdamW, {"lr": 1e-3})
        )
        self.communication_modules = communication_modules
        self.max_norm = max_norm
        # Give each module a back-reference so it can log communicated bytes.
        for m in self.communication_modules:
            m.strategy = self

    def _init_node(self, model, rank, num_nodes):
        # The base class builds self.optim from optim_spec.
        super()._init_node(model, rank, num_nodes)
        for m in self.communication_modules:
            m._init_node(model, rank, num_nodes)
        self._setup_scheduler()

    def step(self):
        # Optional gradient clipping, inner optimizer step, then communication.
        if self.max_norm:
            nn_utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        self.optim.step()
        for m in self.communication_modules:
            m.communicate(self.model, self.rank, self.num_nodes, self.local_step)
        super().step()

class SparseDiLoCoCommunicator(CommunicationModule):
    """DiLoCo-style communicator: every H local steps, top-k sparsified parameters
    (with error feedback) are averaged across nodes, and an outer optimizer is
    applied to a CPU master copy held on rank 0."""

    def __init__(
        self,
        H: int = 15,
        k_ratio: float = 0.01,
        outer_optim_spec: Optional[OptimSpec] = None,
    ):
        self.H = H
        self.k_ratio = k_ratio
        self.outer_optim_spec = outer_optim_spec or OptimSpec(
            torch.optim.SGD, {"lr": 0.7, "momentum": 0.9}
        )
        self.master_model = None
        self.outer_optimizer = None
        self.error_buffers = {}
        self.strategy = None

    def _init_node(self, model, rank, num_nodes):
        # Only rank 0 keeps the CPU master copy and the outer optimizer.
        if rank == 0:
            self.master_model = deepcopy(model).cpu()
            self.outer_optimizer = self.outer_optim_spec.build(self.master_model)
        # Error-feedback buffers are needed on every rank, since every rank sparsifies.
        for name, p in model.named_parameters():
            self.error_buffers[name] = torch.zeros_like(p.data)

    def _sparsify(self, tensor, name):
        # Add the error carried over from previous rounds, keep the top-k entries
        # by magnitude, and store the residual back into the error buffer.
        ef = self.error_buffers[name]
        tensor = tensor + ef
        flat = tensor.view(-1)
        k = max(1, int(self.k_ratio * flat.numel()))
        _, idx = torch.topk(flat.abs(), k, sorted=False)
        sparse = torch.zeros_like(flat)
        sparse[idx] = flat[idx]
        self.error_buffers[name] = (flat - sparse).view_as(tensor)
        if self.strategy:
            # Account for the (dense) buffer that will be all-reduced.
            self.strategy.nbytes += sparse.numel() * sparse.element_size()
        return sparse.view_as(tensor)
    def communicate(self, model, rank, num_nodes, local_step):
        # Only sync every H local steps (never on step 0 or with a single node).
        if num_nodes == 1 or local_step == 0 or local_step % self.H != 0:
            return
        with torch.no_grad():
            # Average sparsified (error-compensated) parameters across nodes.
            for name, p in model.named_parameters():
                sparse = self._sparsify(p.data, name)
                all_reduce(sparse)
                p.data.copy_(sparse / num_nodes)
            if rank == 0:
                # Outer step on the master copy: the pseudo-gradient is the
                # difference between the master weights and the averaged weights.
                self.outer_optimizer.zero_grad()
                for mp, p in zip(
                    self.master_model.parameters(),
                    model.parameters(),
                ):
                    mp.grad = mp.data - p.data.cpu()
                self.outer_optimizer.step()
                # Load the updated master weights back into rank 0's model so the
                # broadcast below distributes the outer-optimizer result.
                for mp, p in zip(
                    self.master_model.parameters(),
                    model.parameters(),
                ):
                    p.data.copy_(mp.data.to(p.device))
            # All ranks resume from rank 0's weights.
            for p in model.parameters():
                broadcast(p.data, src=0)

class DiLoCoStrategy(CommunicateOptimizeStrategy):
    def __init__(
        self,
        optim_spec=None,
        outer_optim_spec=None,
        H: int = 15,
        **kwargs,
    ):
        comm = SparseDiLoCoCommunicator(
            H=H,
            k_ratio=0.01,
            outer_optim_spec=outer_optim_spec,
        )
        super().__init__(
            communication_modules=[comm],
            optim_spec=optim_spec,
            **kwargs,
        )


STRATEGY = DiLoCoStrategy(
    optim_spec=OptimSpec(torch.optim.AdamW, {"lr": 1e-3}),
    lr_scheduler="lambda_cosine",
    lr_scheduler_kwargs={
        "warmup_steps": 500,
        "cosine_anneal": True,
        "max_steps": 100,
    },
    max_norm=1.0,
    H=15,
)
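
# A minimal, illustrative driver (not part of the submission): it assumes a
# torch.distributed process group has already been initialised, that the caller
# supplies the model and data loader, and that a classification loss is adequate
# for the example. It only exercises the methods the strategy defines above
# (_init_node, zero_grad, step).
def _example_train_loop(model, data_loader, rank, num_nodes, max_steps):
    STRATEGY.max_steps = max_steps  # read by the cosine LR schedule
    STRATEGY._init_node(model, rank, num_nodes)
    for step, (x, y) in enumerate(data_loader):
        if step >= max_steps:
            break
        STRATEGY.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(x), y)
        loss.backward()
        STRATEGY.step()  # clip grads, inner step, periodic DiLoCo sync, LR schedule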