@strnan · Created December 30, 2025 15:06
Distributed training strategy submission
import math
import torch
import torch.nn.utils as nn_utils
import torch.distributed as dist
from copy import deepcopy
from dataclasses import dataclass
from torch.optim.lr_scheduler import LambdaLR
from typing import Dict, Any, Type, Optional, Union, List
from abc import ABC, abstractmethod


def all_reduce(tensor):
    # Sums `tensor` in place across all ranks (synchronous dist.all_reduce).
    return dist.all_reduce(tensor, op=dist.ReduceOp.SUM)


@dataclass
class OptimSpec:
    cls: Type[torch.optim.Optimizer]
    kwargs: Dict[str, Any]

    def build(self, model):
        return self.cls(model.parameters(), **self.kwargs)


def ensure_optim_spec(
    optim: Union[str, OptimSpec, None],
    default: Optional[OptimSpec] = None,
    **kwargs,
):
    if optim is None:
        return default or OptimSpec(torch.optim.AdamW, kwargs)
    if isinstance(optim, OptimSpec):
        return optim
    if isinstance(optim, str):
        name = optim.lower()
        mapping = {
            "adam": torch.optim.Adam,
            "adamw": torch.optim.AdamW,
            "sgd": torch.optim.SGD,
        }
        return OptimSpec(mapping[name], kwargs)
    raise TypeError(f"Cannot build an OptimSpec from {type(optim).__name__}")
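

# Illustrative usage (not exercised elsewhere in this file):
#   ensure_optim_spec("sgd", lr=1e-2)   -> OptimSpec(torch.optim.SGD, {"lr": 1e-2})
#   ensure_optim_spec(existing_spec)    -> returned unchanged
#   ensure_optim_spec(None, default)    -> default (or AdamW with the given kwargs)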


class Strategy(ABC):
    def __init__(self, **kwargs):
        self.max_steps = 100
        self.local_step = 0
        self.nbytes = 0  # running count of bytes sent by communication modules
        for k, v in kwargs.items():
            setattr(self, k, v)

    def _init_node(self, model, rank, num_nodes):
        self.model = model
        self.rank = rank
        self.num_nodes = num_nodes
        if hasattr(self, "optim_spec"):
            self.optim = self.optim_spec.build(model)

    def zero_grad(self):
        self.optim.zero_grad()

    @abstractmethod
    def step(self):
        self.local_step += 1


class CommunicationModule(ABC):
    @abstractmethod
    def _init_node(self, model, rank, num_nodes):
        pass

    @abstractmethod
    def communicate(self, model, rank, num_nodes, local_step):
        pass


class SparseAsymDiLoCoEF(CommunicationModule):
    """
    Asymmetric sparse DiLoCo with error feedback.
    Workers all-reduce sparsified parameters toward rank 0 only;
    there is no broadcast back to the workers.
    """

    def __init__(
        self,
        k_ratio=0.005,
        outer_lr=0.2,
    ):
        self.k_ratio = k_ratio
        self.outer_lr = outer_lr
        self.master_model = None
        self.outer_optim = None
        self.error_buffers = {}

    def _init_node(self, model, rank, num_nodes):
        # Per-parameter residual buffers for error feedback.
        for n, p in model.named_parameters():
            self.error_buffers[n] = torch.zeros_like(p.data)
        # Only rank 0 keeps a CPU master copy and an outer (DiLoCo-style) optimizer.
        if rank == 0:
            self.master_model = deepcopy(model).cpu()
            self.outer_optim = torch.optim.SGD(
                self.master_model.parameters(),
                lr=self.outer_lr,
                momentum=0.9,
            )
    def _sparsify(self, tensor, name, strategy):
        # Error feedback: add the residual left over from previous rounds before
        # selecting the top-k entries by magnitude.
        ef = self.error_buffers[name]
        tensor = tensor + ef
        flat = tensor.view(-1)
        k = max(1, int(self.k_ratio * flat.numel()))
        _, idx = torch.topk(flat.abs(), k, sorted=False)
        sparse = torch.zeros_like(flat)
        sparse[idx] = flat[idx]
        # Whatever was not selected is carried forward as the new residual.
        self.error_buffers[name] = (flat - sparse).view_as(tensor)
        # The dense zero-filled buffer is what actually gets all-reduced, so count its full size.
        strategy.nbytes += sparse.numel() * sparse.element_size()
        return sparse.view_as(tensor)
    def communicate(self, model, rank, num_nodes, local_step):
        if num_nodes == 1:
            return
        # Communication interval H grows as training progresses: sync every 4 steps
        # early on, every 12 steps in the middle phase, then every 40 steps.
        if local_step < 20:
            H = 4
        elif local_step < 60:
            H = 12
        else:
            H = 40
        if local_step % H != 0:
            return
        with torch.no_grad():
            for name, p in model.named_parameters():
                sparse = self._sparsify(p.data, name, self.strategy)
                all_reduce(sparse)
                if rank == 0:
                    # Rank 0 adopts the cross-node average of the sparsified parameters.
                    p.data.copy_(sparse / num_nodes)
            if rank == 0:
                # Outer DiLoCo step on the CPU master copy: the pseudo-gradient is the
                # drift between the master weights and rank 0's freshly averaged weights.
                self.outer_optim.zero_grad()
                for mp, p in zip(self.master_model.parameters(), model.parameters()):
                    mp.grad = mp.data - p.data.cpu()
                self.outer_optim.step()
        # No broadcast back to the workers (intentional): only rank 0's view is updated.
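

# Illustrative self-check (a sketch, not called by the submission): the error-feedback
# invariant is that the sparse tensor selected for transmission plus the residual kept
# in `error_buffers` reconstructs the input, so dropped coordinates are delayed, not lost.
def _check_error_feedback_invariant():
    comm = SparseAsymDiLoCoEF(k_ratio=0.5)
    comm.error_buffers["w"] = torch.zeros(4)
    counter = type("Counter", (), {"nbytes": 0})()  # stand-in for the owning Strategy
    x = torch.tensor([1.0, -3.0, 0.5, 2.0])
    sparse = comm._sparsify(x, "w", counter)
    # The top-2 entries by magnitude (-3.0 and 2.0) are kept for transmission;
    # 1.0 and 0.5 stay in the residual buffer for the next round.
    assert torch.allclose(sparse + comm.error_buffers["w"], x)
    assert counter.nbytes == x.numel() * x.element_size()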


class CommunicateOptimizeStrategy(Strategy):
    def __init__(
        self,
        communication_modules: List[CommunicationModule],
        optim_spec=None,
        max_norm=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.optim_spec = ensure_optim_spec(
            optim_spec, OptimSpec(torch.optim.AdamW, {"lr": 2e-3})
        )
        self.communication_modules = communication_modules
        self.max_norm = max_norm
        # Give each module a back-reference so it can account for bytes sent via nbytes.
        for m in self.communication_modules:
            m.strategy = self

    def _init_node(self, model, rank, num_nodes):
        super()._init_node(model, rank, num_nodes)
        self.optim = self.optim_spec.build(model)
        for m in self.communication_modules:
            m._init_node(model, rank, num_nodes)

    def step(self):
        # Inner step: clip gradients, apply the local optimizer, then let each
        # communication module decide whether to synchronize at this step.
        if self.max_norm:
            nn_utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        self.optim.step()
        for m in self.communication_modules:
            m.communicate(self.model, self.rank, self.num_nodes, self.local_step)
        self.local_step += 1


class DiLoCoSparseEFBenchmark(CommunicateOptimizeStrategy):
    def __init__(self, **kwargs):
        comm = SparseAsymDiLoCoEF(
            k_ratio=0.005,
            outer_lr=0.2,
        )
        super().__init__(
            communication_modules=[comm],
            **kwargs,
        )


STRATEGY = DiLoCoSparseEFBenchmark(
    optim_spec=OptimSpec(torch.optim.AdamW, {"lr": 2e-3}),
    max_norm=1.0,
)
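

# Minimal driver sketch (an assumption: the benchmark harness that calls this strategy
# is not part of this gist). Each node would register its model once and then run an
# ordinary zero_grad / backward / step loop; `data_loader` and `loss_fn` are hypothetical
# placeholders for whatever the harness provides.
def _example_training_loop(model, rank, num_nodes, data_loader, loss_fn):
    STRATEGY._init_node(model, rank, num_nodes)
    for inputs, targets in data_loader:
        STRATEGY.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        STRATEGY.step()  # inner AdamW step, then periodic sparse all-reduce
        if STRATEGY.local_step >= STRATEGY.max_steps:
            break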