Distributed training strategy submission: a DiLoCo-style strategy for exogym that synchronizes every H steps using top-k sparsified parameters with error feedback.
import math

import torch
import torch.nn.utils as nn_utils
import torch.distributed as dist

from copy import deepcopy
from dataclasses import dataclass
from torch.optim.lr_scheduler import LambdaLR
from typing import List, Type, Union, Optional, Dict, Any
from abc import ABC, abstractmethod

from exogym.aux.utils import LogModule


# Thin wrappers around the torch.distributed collectives used below.
def broadcast(tensor, src=0):
    return dist.broadcast(tensor, src=src)


def all_reduce(tensor, op=dist.ReduceOp.SUM):
    return dist.all_reduce(tensor, op=op)
@dataclass
class OptimSpec:
    """An optimizer class plus constructor kwargs, built lazily against a model."""

    cls: Type[torch.optim.Optimizer]
    kwargs: Dict[str, Any]

    def build(self, model):
        return self.cls(model.parameters(), **self.kwargs)


def ensure_optim_spec(
    optim: Union[str, "OptimSpec", None],
    default: Optional["OptimSpec"] = None,
    **kwargs,
) -> "OptimSpec":
    """Normalize None, a string name, or an OptimSpec into an OptimSpec."""
    if optim is None:
        return default or OptimSpec(torch.optim.AdamW, kwargs)
    if isinstance(optim, OptimSpec):
        return optim
    if isinstance(optim, str):
        name = optim.lower()
        mapping = {
            "adam": torch.optim.Adam,
            "adamw": torch.optim.AdamW,
            "sgd": torch.optim.SGD,
        }
        return OptimSpec(mapping[name], kwargs)
    raise TypeError(f"Cannot build an OptimSpec from {type(optim)!r}")
class Strategy(ABC, LogModule):
    def __init__(self, lr_scheduler=None, lr_scheduler_kwargs=None, **kwargs):
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or {}
        self.scheduler = None
        self.lr_callbacks = []
        self.max_steps = 1
        # Any remaining kwargs (e.g. max_steps set by the trainer) become attributes.
        for k, v in kwargs.items():
            setattr(self, k, v)

    def _init_node(self, model, rank, num_nodes):
        self.model = model
        self.rank = rank
        self.num_nodes = num_nodes
        self.local_step = 0
        self.nbytes = 0  # running count of bytes communicated
        if hasattr(self, "optim_spec"):
            self.optim = self.optim_spec.build(model)

    def zero_grad(self):
        self.optim.zero_grad()

    def _setup_scheduler(self):
        def lr_lambda(step):
            # Linear warmup followed by optional cosine annealing.
            warmup = self.lr_scheduler_kwargs.get("warmup_steps", 1)
            max_steps = min(
                self.lr_scheduler_kwargs.get("max_steps", self.max_steps),
                self.max_steps,
            )
            cosine = self.lr_scheduler_kwargs.get("cosine_anneal", True)
            if step < warmup:
                return step / max(1, warmup)
            if cosine:
                progress = (step - warmup) / max(1, max_steps - warmup)
                return 0.5 * (1 + math.cos(math.pi * progress))
            return 1.0

        if self.lr_scheduler == "lambda_cosine":
            self.scheduler = LambdaLR(self.optim, lr_lambda)

    @abstractmethod
    def step(self):
        # Subclasses call super().step() after their own update logic.
        if self.scheduler:
            self.scheduler.step()
        self.local_step += 1
class CommunicationModule(ABC):
    @abstractmethod
    def _init_node(self, model, rank, num_nodes):
        pass

    @abstractmethod
    def communicate(self, model, rank, num_nodes, local_step):
        pass
class CommunicateOptimizeStrategy(Strategy):
    """Runs a local optimizer step, then hands the model to its communication modules."""

    def __init__(
        self,
        communication_modules: List[CommunicationModule],
        optim_spec=None,
        max_norm=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.optim_spec = ensure_optim_spec(
            optim_spec, OptimSpec(torch.optim.AdamW, {"lr": 1e-3})
        )
        self.communication_modules = communication_modules
        self.max_norm = max_norm
        for m in self.communication_modules:
            m.strategy = self

    def _init_node(self, model, rank, num_nodes):
        super()._init_node(model, rank, num_nodes)
        self.optim = self.optim_spec.build(model)
        for m in self.communication_modules:
            m._init_node(model, rank, num_nodes)
        self._setup_scheduler()

    def step(self):
        if self.max_norm:
            nn_utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        self.optim.step()
        for m in self.communication_modules:
            m.communicate(self.model, self.rank, self.num_nodes, self.local_step)
        super().step()
class SparseDiLoCoCommunicator(CommunicationModule):
    """Every H steps: top-k sparsify with error feedback, all-reduce, outer step on rank 0."""

    def __init__(
        self,
        H: int = 15,
        k_ratio: float = 0.01,
        outer_optim_spec: Optional[OptimSpec] = None,
    ):
        self.H = H
        self.k_ratio = k_ratio
        self.outer_optim_spec = outer_optim_spec or OptimSpec(
            torch.optim.SGD, {"lr": 0.7, "momentum": 0.9}
        )
        self.master_model = None
        self.outer_optimizer = None
        self.error_buffers = {}
        self.strategy = None

    def _init_node(self, model, rank, num_nodes):
        # Rank 0 keeps a CPU copy of the model as the master for the outer optimizer.
        if rank == 0:
            self.master_model = deepcopy(model).cpu()
            self.outer_optimizer = self.outer_optim_spec.build(self.master_model)
        for name, p in model.named_parameters():
            self.error_buffers[name] = torch.zeros_like(p.data)

    def _sparsify(self, tensor, name):
        # Error feedback: add the residual from the previous round, keep the top-k
        # entries by magnitude, and carry the remainder over to the next round.
        ef = self.error_buffers[name]
        tensor = tensor + ef
        flat = tensor.view(-1)
        k = max(1, int(self.k_ratio * flat.numel()))
        _, idx = torch.topk(flat.abs(), k, sorted=False)
        sparse = torch.zeros_like(flat)
        sparse[idx] = flat[idx]
        self.error_buffers[name] = (flat - sparse).view_as(tensor)
        if self.strategy:
            self.strategy.nbytes += sparse.numel() * sparse.element_size()
        return sparse.view_as(tensor)
    def communicate(self, model, rank, num_nodes, local_step):
        # Only synchronize every H local steps.
        if num_nodes == 1 or local_step == 0 or local_step % self.H != 0:
            return
        with torch.no_grad():
            # Sparsified all-reduce: average the top-k (error-compensated) entries.
            for name, p in model.named_parameters():
                sparse = self._sparsify(p.data, name)
                all_reduce(sparse)
                p.data.copy_(sparse / num_nodes)
            if rank == 0:
                # Outer (DiLoCo) step: pseudo-gradient is master minus averaged params.
                self.outer_optimizer.zero_grad()
                for mp, p in zip(
                    self.master_model.parameters(),
                    model.parameters(),
                ):
                    mp.grad = mp.data - p.data.cpu()
                self.outer_optimizer.step()
                # Copy the updated master weights back into the local model so that
                # the broadcast below propagates the outer step; without this sync
                # the outer optimizer would have no effect on the replicas.
                for mp, p in zip(
                    self.master_model.parameters(),
                    model.parameters(),
                ):
                    p.data.copy_(mp.data)
            # All ranks adopt rank 0's master-updated weights.
            for p in model.parameters():
                broadcast(p.data, src=0)
class DiLoCoStrategy(CommunicateOptimizeStrategy):
    def __init__(
        self,
        optim_spec=None,
        outer_optim_spec=None,
        H: int = 15,
        **kwargs,
    ):
        comm = SparseDiLoCoCommunicator(
            H=H,
            k_ratio=0.01,
            outer_optim_spec=outer_optim_spec,
        )
        super().__init__(
            communication_modules=[comm],
            optim_spec=optim_spec,
            **kwargs,
        )


STRATEGY = DiLoCoStrategy(
    optim_spec=OptimSpec(torch.optim.AdamW, {"lr": 1e-3}),
    lr_scheduler="lambda_cosine",
    lr_scheduler_kwargs={
        "warmup_steps": 500,
        "cosine_anneal": True,
        "max_steps": 100,
    },
    max_norm=1.0,
    H=15,
)
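
As a quick sanity check of the classes above, the strategy can be driven by hand on a single node: with num_nodes=1 the communicator returns early, so no process group is required (exogym must still be installed so the LogModule import resolves). The toy nn.Linear model, random batch, and step count below are made up for illustration and are not part of the submission.

import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(8, 2)  # toy model, illustration only
strategy = DiLoCoStrategy(optim_spec=OptimSpec(torch.optim.AdamW, {"lr": 1e-3}), H=15)
strategy._init_node(model, rank=0, num_nodes=1)  # single node: communicate() is a no-op

for _ in range(5):
    strategy.zero_grad()
    x, y = torch.randn(4, 8), torch.randn(4, 2)
    loss = F.mse_loss(model(x), y)
    loss.backward()
    strategy.step()  # inner AdamW step; the sparse DiLoCo sync fires every H steps when num_nodes > 1

In a real run the trainer is expected to call _init_node once per worker with the actual rank and node count, and to set max_steps so the cosine schedule spans the full training horizon.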