Distributed training strategy submission: per-node AdamW training with a single late parameter average across nodes.
import torch
import torch.nn.utils as nn_utils
import torch.distributed as dist
from dataclasses import dataclass
from typing import Dict, Any, Type
from abc import ABC, abstractmethod


def all_reduce(t):
    """Sum-reduce a tensor in place across all ranks."""
    dist.all_reduce(t, op=dist.ReduceOp.SUM)


@dataclass
class OptimSpec:
    """An optimizer class plus constructor kwargs, built per model."""
    cls: Type[torch.optim.Optimizer]
    kwargs: Dict[str, Any]

    def build(self, model):
        return self.cls(model.parameters(), **self.kwargs)


class Strategy(ABC):
    """Base class for a per-node training strategy."""

    def __init__(self):
        self.local_step = 0
        self.max_steps = 100
        self.nbytes = 0  # communication-volume counter (not updated here)
        self.device = None

    def _init_node(self, model, rank, num_nodes):
        self.rank = rank
        self.num_nodes = num_nodes
        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{rank}")
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")
        self.model = model.to(self.device)

    def zero_grad(self):
        # self.optim is expected to be created by the subclass in _init_node.
        self.optim.zero_grad(set_to_none=True)

    @abstractmethod
    def step(self):
        pass


class CommunicationModule(ABC):
    """Pluggable cross-node communication policy."""

    def _init_node(self, model, rank, num_nodes):
        pass

    @abstractmethod
    def communicate(self, model, rank, num_nodes, local_step):
        pass


class LateSyncCommunicator(CommunicationModule):
    """Trains fully locally, then averages parameters once at `sync_step`."""

    def __init__(self, sync_step=80):
        self.sync_step = sync_step
        self.done = False

    def communicate(self, model, rank, num_nodes, local_step):
        if num_nodes <= 1:
            return
        if local_step < self.sync_step:
            return
        if self.done:
            return
        with torch.no_grad():
            for p in model.parameters():
                if not p.is_cuda:
                    continue  # only GPU-resident parameters are synced
                all_reduce(p.data)
                p.data.div_(num_nodes)  # turn the summed params into a mean
        self.done = True

    def _init_node(self, model, rank, num_nodes):
        pass


class LocalFirstStrategy(Strategy):
    """Local AdamW steps with a single late parameter average."""

    def __init__(self):
        super().__init__()
        self.comm = LateSyncCommunicator(sync_step=80)

    def _init_node(self, model, rank, num_nodes):
        super()._init_node(model, rank, num_nodes)
        self.optim = torch.optim.AdamW(
            self.model.parameters(),
            lr=3e-4,
            betas=(0.9, 0.95),
            weight_decay=0.01,
        )
        self.comm._init_node(self.model, rank, num_nodes)

    def step(self):
        nn_utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optim.step()
        self.comm.communicate(
            self.model, self.rank, self.num_nodes, self.local_step
        )
        self.local_step += 1


STRATEGY = LocalFirstStrategy()
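
A minimal sketch of how a harness might drive this strategy. The process-group setup, the `nn.Linear` stand-in model, and the synthetic batch are assumptions for illustration; the actual runner that consumes `STRATEGY` is not part of this gist.

import os
import torch
import torch.distributed as dist
import torch.nn as nn

def main():
    # Assumed launch via torchrun, which sets RANK/WORLD_SIZE (and the
    # MASTER_ADDR/MASTER_PORT needed by init_process_group).
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if world_size > 1:
        dist.init_process_group("nccl" if torch.cuda.is_available() else "gloo")

    model = nn.Linear(32, 8)  # hypothetical stand-in model
    strategy = LocalFirstStrategy()
    strategy._init_node(model, rank, world_size)

    for _ in range(strategy.max_steps):
        x = torch.randn(16, 32, device=strategy.device)  # synthetic batch
        loss = strategy.model(x).pow(2).mean()
        strategy.zero_grad()
        loss.backward()
        # Local AdamW step; LateSyncCommunicator averages CUDA params
        # once, at local step 80.
        strategy.step()

    if world_size > 1:
        dist.destroy_process_group()

if __name__ == "__main__":
    main()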