Created January 2, 2026 18:59
Distributed training strategy submission
import math
import torch
import torch.distributed as dist
import torch.nn.utils as nn_utils
from copy import deepcopy
from dataclasses import dataclass
from typing import Type, Dict, Any, Iterable, Union
from abc import ABC, abstractmethod
from torch.optim.lr_scheduler import LambdaLR

ParamsT = Union[Iterable[torch.Tensor], Iterable[dict[str, Any]]]

@dataclass
class OptimSpec:
    """Optimizer class plus constructor kwargs; builds the optimizer for a given model."""

    cls: Type[torch.optim.Optimizer]
    kwargs: Dict[str, Any]

    def build(self, model):
        return self.cls(model.parameters(), **self.kwargs)
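
# --- Illustration only (not part of the submission): OptimSpec defers optimizer
# construction until a concrete model exists. The Linear module below is a made-up
# stand-in for whatever model the training harness provides.
_example_model = torch.nn.Linear(8, 2)
_example_optim = OptimSpec(torch.optim.AdamW, {"lr": 3e-4}).build(_example_model)
# ...equivalent to torch.optim.AdamW(_example_model.parameters(), lr=3e-4)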

class Strategy(ABC):
    """Base class for a per-node training strategy: owns the optimizer and LR schedule."""

    def __init__(self, lr_scheduler=None, lr_scheduler_kwargs=None):
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or {}
        self.scheduler = None
        self.max_steps = 1

    def _init_node(self, model, rank, num_nodes):
        self.model = model
        self.rank = rank
        self.num_nodes = num_nodes
        self.local_step = 0

    def zero_grad(self):
        self.optim.zero_grad()

    def _setup_scheduler(self):
        def lr_lambda(step):
            # Linear warmup to the base LR, then cosine decay toward zero at max_steps.
            warmup = self.lr_scheduler_kwargs.get("warmup_steps", 1)
            max_steps = self.lr_scheduler_kwargs.get("max_steps", self.max_steps)
            if step < warmup:
                return step / max(1, warmup)
            progress = (step - warmup) / max(1, max_steps - warmup)
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        if self.lr_scheduler == "lambda_cosine":
            self.scheduler = LambdaLR(self.optim, lr_lambda)

    @abstractmethod
    def step(self):
        self.local_step += 1
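
# --- Illustration only: the multiplier produced by lr_lambda above ramps linearly
# from 0 to 1 over warmup_steps, then follows a half-cosine from 1 down to 0 at
# max_steps. A standalone copy of that shape, for sanity checking:
def _preview_lr_multiplier(step, warmup_steps=100, max_steps=1000):
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))
# _preview_lr_multiplier(50) == 0.5, _preview_lr_multiplier(100) == 1.0,
# _preview_lr_multiplier(1000) == 0.0 (up to floating point)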

class CommunicationModule(ABC):
    """Interface for the cross-node communication hook invoked by the strategy."""

    @abstractmethod
    def _init_node(self, model, rank, num_nodes):
        pass

    @abstractmethod
    def communicate(self, model, rank, num_nodes, local_step):
        pass

class CommunicateOptimizeStrategy(Strategy):
    """Strategy that runs a local optimizer step, then hands off to communication modules."""

    def __init__(self, communication_modules, optim_spec, max_norm=None, **kwargs):
        super().__init__(**kwargs)
        self.communication_modules = communication_modules
        self.optim_spec = optim_spec
        self.max_norm = max_norm
        for m in self.communication_modules:
            m.strategy = self

    def _init_node(self, model, rank, num_nodes):
        super()._init_node(model, rank, num_nodes)
        self.optim = self.optim_spec.build(model)
        self._setup_scheduler()
        for m in self.communication_modules:
            m._init_node(model, rank, num_nodes)

    def step(self):
        # Order per local step: clip gradients, inner optimizer step,
        # communication hooks, then LR scheduler.
        if self.max_norm:
            nn_utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        self.optim.step()
        for m in self.communication_modules:
            m.communicate(self.model, self.rank, self.num_nodes, self.local_step)
        if self.scheduler:
            self.scheduler.step()
        self.local_step += 1

class DiLoCoEMACommunicator(CommunicationModule):
    """Every H local steps: all-reduce-average the parameters across nodes, fold the
    average into an EMA shadow copy, and replace the local parameters with the shadow."""

    def __init__(self, H=64, ema=0.9):
        self.H = H
        self.ema = ema

    def _init_node(self, model, rank, num_nodes):
        # Shadow copy of the parameters, plus a 1-element buffer that is all-reduced
        # before the parameter exchange.
        self.shadow = [p.detach().clone() for p in model.parameters()]
        self.buffer = torch.zeros(1, device=next(model.parameters()).device)

    @torch.no_grad()
    def communicate(self, model, rank, num_nodes, local_step):
        if num_nodes <= 1:
            return
        if local_step == 0 or local_step % self.H != 0:
            return
        self.buffer.fill_(1.0)
        dist.all_reduce(self.buffer, op=dist.ReduceOp.SUM)
        scale = 1.0 / float(num_nodes)
        for s, p in zip(self.shadow, model.parameters()):
            dist.all_reduce(p.data, op=dist.ReduceOp.SUM)
            p.data.mul_(scale)                                    # average across nodes
            s.mul_(self.ema).add_(p.data, alpha=1.0 - self.ema)   # EMA into shadow
            p.data.copy_(s)                                       # local params adopt the shadow
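
# --- Illustration only: the outer update above, written out for a single tensor and
# without torch.distributed. The two values below are made-up per-node parameters.
_node_values = [torch.tensor([1.0]), torch.tensor([3.0])]
_avg = torch.stack(_node_values).mean(dim=0)       # all-reduce average    -> tensor([2.0])
_shadow = torch.tensor([0.0])
_shadow = 0.9 * _shadow + (1.0 - 0.9) * _avg       # EMA fold-in (ema=0.9) -> tensor([0.2])
_new_local = _shadow.clone()                       # every node adopts the shadow copy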

class DiLoCoStrategy(CommunicateOptimizeStrategy):
    """Convenience wrapper wiring a DiLoCoEMACommunicator into CommunicateOptimizeStrategy."""

    def __init__(self, optim_spec, H=64, ema=0.9, **kwargs):
        self.comm = DiLoCoEMACommunicator(H=H, ema=ema)
        super().__init__(
            communication_modules=[self.comm],
            optim_spec=optim_spec,
            **kwargs,
        )

STRATEGY = DiLoCoStrategy(
    optim_spec=OptimSpec(
        torch.optim.AdamW,
        {"lr": 0.001},
    ),
    lr_scheduler="lambda_cosine",
    lr_scheduler_kwargs={
        "warmup_steps": 800,
        "max_steps": 100,
    },
    max_norm=1.5,
    H=64,
    ema=0.9,
)
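
# --- Illustration only: a sketch of how a per-node harness might drive STRATEGY.
# The names `model`, `get_batch`, `loss_fn`, `rank`, and `num_nodes` are assumptions
# about the surrounding framework, not part of this file; torch.distributed is
# assumed to be initialized already.
def _example_train_loop(model, get_batch, loss_fn, rank, num_nodes, steps=1000):
    STRATEGY._init_node(model, rank, num_nodes)   # builds AdamW, scheduler, shadow params
    for _ in range(steps):
        STRATEGY.zero_grad()
        inputs, targets = get_batch()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        # clip -> AdamW step -> (every H=64 steps) all-reduce + EMA sync -> LR step
        STRATEGY.step()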