Skip to content

Instantly share code, notes, and snippets.

@strnan
Created January 2, 2026 18:59
Show Gist options
  • Select an option

  • Save strnan/b53f8ce497f4caed2570b50630580c48 to your computer and use it in GitHub Desktop.

Select an option

Save strnan/b53f8ce497f4caed2570b50630580c48 to your computer and use it in GitHub Desktop.
Distributed training strategy submission
import math
import torch
import torch.distributed as dist
import torch.nn.utils as nn_utils
from copy import deepcopy
from dataclasses import dataclass
from typing import Type, Dict, Any, Iterable, Union
from abc import ABC, abstractmethod
from torch.optim.lr_scheduler import LambdaLR
# Type alias: either an iterable of tensors or an iterable of string-keyed dicts.
# NOTE(review): presumably describes the parameter argument accepted by torch
# optimizers; it is not referenced anywhere in the visible code — confirm use.
ParamsT = Union[Iterable[torch.Tensor], Iterable[dict[str, Any]]]
@dataclass
class OptimSpec:
    """Recipe for constructing an optimizer once a model is available.

    Attributes:
        cls: Optimizer class to instantiate.
        kwargs: Keyword arguments forwarded to the optimizer constructor.
    """

    cls: Type[torch.optim.Optimizer]
    kwargs: Dict[str, Any]

    def build(self, model):
        """Instantiate ``cls`` over ``model.parameters()`` using ``kwargs``."""
        optimizer_cls = self.cls
        params = model.parameters()
        return optimizer_cls(params, **self.kwargs)
class Strategy(ABC):
    """Abstract base for a per-node training strategy.

    Holds the learning-rate scheduler configuration and the per-node state
    (model, rank, node count, local step counter).  ``self.optim`` is not
    created here; a subclass must assign it before ``zero_grad`` or
    ``_setup_scheduler`` is called.

    Args:
        lr_scheduler: Scheduler name.  Only ``"lambda_cosine"`` is recognised;
            any other value leaves ``self.scheduler`` as ``None``.
        lr_scheduler_kwargs: Optional dict; the keys ``"warmup_steps"`` and
            ``"max_steps"`` are read by the cosine schedule.
    """

    def __init__(self, lr_scheduler=None, lr_scheduler_kwargs=None):
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or {}
        self.scheduler = None
        # Fallback schedule horizon, used when "max_steps" is absent from
        # lr_scheduler_kwargs.
        self.max_steps = 1

    def _init_node(self, model, rank, num_nodes):
        """Record the node's model, rank and node count; reset the step counter."""
        self.model = model
        self.rank = rank
        self.num_nodes = num_nodes
        self.local_step = 0

    def zero_grad(self):
        """Clear gradients through the optimizer held in ``self.optim``."""
        self.optim.zero_grad()

    def _setup_scheduler(self):
        """Attach a linear-warmup + cosine-decay ``LambdaLR`` if configured.

        The multiplier rises linearly from 0 to 1 over ``warmup_steps``, then
        follows half a cosine from 1 down to 0 at ``max_steps``.  Progress is
        clamped to ``[0, 1]`` so the multiplier stays at 0 once ``max_steps``
        is passed instead of riding the cosine back up (which, when
        ``max_steps <= warmup_steps``, would make it flip between 1 and 0 on
        alternating steps).
        """

        def lr_lambda(step):
            # Read the settings on every call so later changes to
            # self.max_steps / lr_scheduler_kwargs are picked up.
            warmup = self.lr_scheduler_kwargs.get("warmup_steps", 1)
            max_steps = self.lr_scheduler_kwargs.get("max_steps", self.max_steps)
            if step < warmup:
                return step / max(1, warmup)
            progress = (step - warmup) / max(1, max_steps - warmup)
            # Clamp: beyond the decay horizon the cosine must not oscillate.
            progress = min(1.0, max(0.0, progress))
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        if self.lr_scheduler == "lambda_cosine":
            self.scheduler = LambdaLR(self.optim, lr_lambda)

    @abstractmethod
    def step(self):
        """Advance the local step counter; subclasses implement the real step."""
        self.local_step += 1
class CommunicationModule(ABC):
    """Interface for a component that exchanges model state between nodes."""

    @abstractmethod
    def _init_node(self, model, rank, num_nodes):
        """Set up per-node state for ``model`` on node ``rank`` of ``num_nodes``."""
        return None

    @abstractmethod
    def communicate(self, model, rank, num_nodes, local_step):
        """Run the communication for ``local_step``; the base version is a no-op."""
        return None
class CommunicateOptimizeStrategy(Strategy):
    """Strategy that takes a local optimizer step, then runs communication modules.

    Args:
        communication_modules: Objects exposing ``_init_node`` and
            ``communicate``; each gets a ``strategy`` attribute pointing back
            at this instance.
        optim_spec: Object whose ``build(model)`` returns the optimizer.
        max_norm: Gradient-norm clipping threshold; a falsy value (``None``
            or ``0``) disables clipping.
        **kwargs: Forwarded to ``Strategy.__init__``.
    """

    def __init__(self, communication_modules, optim_spec, max_norm=None, **kwargs):
        super().__init__(**kwargs)
        self.communication_modules = communication_modules
        self.optim_spec = optim_spec
        self.max_norm = max_norm
        # Give every module a back-reference to the owning strategy.
        for module in self.communication_modules:
            module.strategy = self

    def _init_node(self, model, rank, num_nodes):
        """Build the optimizer and scheduler, then initialise each module."""
        super()._init_node(model, rank, num_nodes)
        self.optim = self.optim_spec.build(model)
        self._setup_scheduler()
        for module in self.communication_modules:
            module._init_node(model, rank, num_nodes)

    def step(self):
        """Clip gradients, step the optimizer, communicate, step the scheduler.

        Modules receive ``local_step`` as it stands *before* this step's
        increment, which happens last.
        """
        if self.max_norm:
            nn_utils.clip_grad_norm_(self.model.parameters(), self.max_norm)

        self.optim.step()

        current_step = self.local_step
        for module in self.communication_modules:
            module.communicate(self.model, self.rank, self.num_nodes, current_step)

        if self.scheduler:
            self.scheduler.step()

        self.local_step += 1
class DiLoCoEMACommunicator(CommunicationModule):
    """Periodic cross-node parameter averaging blended into an EMA shadow copy.

    On every ``H``-th local step (step 0 excluded) each parameter is summed
    across nodes and scaled by ``1 / num_nodes``; the result is folded into a
    shadow copy with weight ``1 - ema``, and the parameter is then overwritten
    with that shadow value.

    Args:
        H: Number of local steps between communications.
        ema: Weight kept from the previous shadow value at each communication.
    """

    def __init__(self, H=64, ema=0.9):
        self.H = H
        self.ema = ema

    def _init_node(self, model, rank, num_nodes):
        """Snapshot the parameters as the shadow copy and allocate a scratch buffer."""
        self.shadow = [param.detach().clone() for param in model.parameters()]
        first_param = next(model.parameters())
        self.buffer = torch.zeros(1, device=first_param.device)

    @torch.no_grad()
    def communicate(self, model, rank, num_nodes, local_step):
        """Average parameters across nodes and apply the EMA update when due."""
        # Nothing to do on a single node, at step 0, or between sync points.
        if num_nodes <= 1 or local_step == 0 or local_step % self.H != 0:
            return

        # The reduced value of this buffer is never read afterwards.
        # NOTE(review): presumably serves as a synchronisation point — confirm.
        self.buffer.fill_(1.0)
        dist.all_reduce(self.buffer, op=dist.ReduceOp.SUM)

        scale = 1.0 / float(num_nodes)
        keep = self.ema
        for shadow_param, param in zip(self.shadow, model.parameters()):
            # Sum over nodes, then turn the sum into a mean.
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data.mul_(scale)
            # shadow <- ema * shadow + (1 - ema) * mean
            shadow_param.mul_(keep).add_(param.data, alpha=1.0 - keep)
            # The model continues from the smoothed value.
            param.data.copy_(shadow_param)
class DiLoCoStrategy(CommunicateOptimizeStrategy):
    """``CommunicateOptimizeStrategy`` wired to a single ``DiLoCoEMACommunicator``.

    Args:
        optim_spec: Optimizer specification passed to the parent class.
        H: Communication interval handed to the communicator.
        ema: EMA weight handed to the communicator.
        **kwargs: Forwarded to ``CommunicateOptimizeStrategy.__init__``.
    """

    def __init__(self, optim_spec, H=64, ema=0.9, **kwargs):
        # Kept on the instance so the communicator stays directly reachable.
        self.comm = DiLoCoEMACommunicator(H=H, ema=ema)
        modules = [self.comm]
        super().__init__(communication_modules=modules, optim_spec=optim_spec, **kwargs)
# Module-level strategy instance: AdamW with a warmup + cosine schedule,
# gradient clipping, and parameter communication every 64 local steps.
# NOTE(review): "warmup_steps" (800) is larger than "max_steps" (100) —
# confirm this is intended.
STRATEGY = DiLoCoStrategy(
    optim_spec=OptimSpec(cls=torch.optim.AdamW, kwargs={"lr": 0.001}),
    lr_scheduler="lambda_cosine",
    lr_scheduler_kwargs={"warmup_steps": 800, "max_steps": 100},
    max_norm=1.5,
    H=64,
    ema=0.9,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment