Distributed training strategy submission
import math
import torch
import torch.nn.utils as nn_utils
import torch.distributed as dist
from copy import deepcopy
from dataclasses import dataclass
from torch.optim.lr_scheduler import LambdaLR
from typing import Dict, Any, Type, Optional, Union, List
from abc import ABC, abstractmethod


def all_reduce(tensor):
    """Sum-reduce a tensor in place across all ranks."""
    return dist.all_reduce(tensor, op=dist.ReduceOp.SUM)


@dataclass
class OptimSpec:
    """An optimizer class plus constructor kwargs, built lazily once a model exists."""
    cls: Type[torch.optim.Optimizer]
    kwargs: Dict[str, Any]

    def build(self, model):
        return self.cls(model.parameters(), **self.kwargs)


def ensure_optim_spec(
    optim: Union[str, OptimSpec, None],
    default: Optional[OptimSpec] = None,
    **kwargs,
) -> OptimSpec:
    """Normalize a string, OptimSpec, or None into an OptimSpec."""
    if optim is None:
        return default or OptimSpec(torch.optim.AdamW, kwargs)
    if isinstance(optim, OptimSpec):
        return optim
    if isinstance(optim, str):
        name = optim.lower()
        mapping = {
            "adam": torch.optim.Adam,
            "adamw": torch.optim.AdamW,
            "sgd": torch.optim.SGD,
        }
        return OptimSpec(mapping[name], kwargs)
    raise TypeError(f"Cannot build an OptimSpec from {type(optim)!r}")


class Strategy(ABC):
    def __init__(self, **kwargs):
        self.max_steps = 100
        self.local_step = 0
        self.nbytes = 0  # running total of bytes communicated by this worker
        for k, v in kwargs.items():
            setattr(self, k, v)

    def _init_node(self, model, rank, num_nodes):
        self.model = model
        self.rank = rank
        self.num_nodes = num_nodes
        if hasattr(self, "optim_spec"):
            self.optim = self.optim_spec.build(model)

    def zero_grad(self):
        self.optim.zero_grad()

    @abstractmethod
    def step(self):
        self.local_step += 1


class CommunicationModule(ABC):
    @abstractmethod
    def _init_node(self, model, rank, num_nodes):
        pass

    @abstractmethod
    def communicate(self, model, rank, num_nodes, local_step):
        pass


class SparseAsymDiLoCoEF(CommunicationModule):
    """
    Asymmetric sparse DiLoCo with error feedback:
    workers send top-k values to rank 0 only;
    there is intentionally NO broadcast back to the workers.
    """

    def __init__(
        self,
        k_ratio=0.005,
        outer_lr=0.2,
    ):
        self.k_ratio = k_ratio
        self.outer_lr = outer_lr
        self.master_model = None
        self.outer_optim = None
        self.error_buffers = {}

    def _init_node(self, model, rank, num_nodes):
        for n, p in model.named_parameters():
            self.error_buffers[n] = torch.zeros_like(p.data)
        if rank == 0:
            # Rank 0 keeps a CPU copy of the model and an outer SGD optimizer.
            self.master_model = deepcopy(model).cpu()
            self.outer_optim = torch.optim.SGD(
                self.master_model.parameters(),
                lr=self.outer_lr,
                momentum=0.9,
            )

    def _sparsify(self, tensor, name, strategy):
        # Add the accumulated error, keep only the top-k entries by magnitude,
        # and store the dropped residual back into the error buffer.
        ef = self.error_buffers[name]
        tensor = tensor + ef
        flat = tensor.view(-1)
        k = max(1, int(self.k_ratio * flat.numel()))
        _, idx = torch.topk(flat.abs(), k, sorted=False)
        sparse = torch.zeros_like(flat)
        sparse[idx] = flat[idx]
        self.error_buffers[name] = (flat - sparse).view_as(tensor)
        strategy.nbytes += sparse.numel() * sparse.element_size()
        return sparse.view_as(tensor)

    def communicate(self, model, rank, num_nodes, local_step):
        if num_nodes == 1:
            return
        # The communication interval H grows as training progresses.
        if local_step < 20:
            H = 4
        elif local_step < 60:
            H = 12
        else:
            H = 40
        if local_step % H != 0:
            return
        with torch.no_grad():
            for name, p in model.named_parameters():
                sparse = self._sparsify(p.data, name, self.strategy)
                all_reduce(sparse)
                if rank == 0:
                    p.data.copy_(sparse / num_nodes)
            if rank == 0:
                # Outer step: treat (master - averaged worker) as a pseudo-gradient.
                self.outer_optim.zero_grad()
                for mp, p in zip(self.master_model.parameters(), model.parameters()):
                    mp.grad = mp.data - p.data.cpu()
                self.outer_optim.step()
        # NO BROADCAST BACK (INTENTIONAL)


class CommunicateOptimizeStrategy(Strategy):
    def __init__(
        self,
        communication_modules: List[CommunicationModule],
        optim_spec=None,
        max_norm=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.optim_spec = ensure_optim_spec(
            optim_spec, OptimSpec(torch.optim.AdamW, {"lr": 2e-3})
        )
        self.communication_modules = communication_modules
        self.max_norm = max_norm
        for m in self.communication_modules:
            m.strategy = self

    def _init_node(self, model, rank, num_nodes):
        super()._init_node(model, rank, num_nodes)
        self.optim = self.optim_spec.build(model)
        for m in self.communication_modules:
            m._init_node(model, rank, num_nodes)

    def step(self):
        if self.max_norm:
            nn_utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        self.optim.step()
        for m in self.communication_modules:
            m.communicate(self.model, self.rank, self.num_nodes, self.local_step)
        self.local_step += 1


class DiLoCoSparseEFBenchmark(CommunicateOptimizeStrategy):
    def __init__(self, **kwargs):
        comm = SparseAsymDiLoCoEF(
            k_ratio=0.005,
            outer_lr=0.2,
        )
        super().__init__(
            communication_modules=[comm],
            **kwargs,
        )


STRATEGY = DiLoCoSparseEFBenchmark(
    optim_spec=OptimSpec(torch.optim.AdamW, {"lr": 2e-3}),
    max_norm=1.0,
)
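For context, here is a minimal sketch of how a training loop might drive this strategy. The harness that actually consumes STRATEGY is not part of this gist, so run_strategy, model, loss_fn, and get_batch below are hypothetical placeholders; only the Strategy API calls (_init_node, zero_grad, step, max_steps, nbytes) come from the code above, and a torch.distributed process group is assumed to be initialized already (e.g. via torchrun).

# Hypothetical driver loop; the real benchmark harness is not shown in this gist.
import torch.distributed as dist

def run_strategy(model, loss_fn, get_batch, strategy=STRATEGY):
    rank = dist.get_rank()
    num_nodes = dist.get_world_size()
    strategy._init_node(model, rank, num_nodes)  # builds the inner optimizer and comm state
    for _ in range(strategy.max_steps):
        strategy.zero_grad()
        inputs, targets = get_batch()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        strategy.step()  # clip, inner optimizer step, then communicate if due
    return strategy.nbytes  # total bytes this worker sent

Under this sketch, each call to strategy.step() performs a local AdamW update and, every H local steps, the sparse asymmetric exchange described above.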