Distributed training strategy submission
import os
import torch
import torch.nn.utils as nn_utils
import torch.distributed as dist
from dataclasses import dataclass
from typing import Dict, Any, Type
from abc import ABC, abstractmethod


def all_reduce(t):
    # Sum the tensor across all ranks in place.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)


@dataclass
class OptimSpec:
    # Optimizer class plus keyword arguments, so a strategy can be
    # configured declaratively and built against any model.
    cls: Type[torch.optim.Optimizer]
    kwargs: Dict[str, Any]

    def build(self, model):
        return self.cls(model.parameters(), **self.kwargs)


class Strategy(ABC):
    # Base class for a per-node training strategy: owns the model,
    # device placement, and the local step counter.
    def __init__(self):
        self.local_step = 0
        self.max_steps = 100
        self.nbytes = 0  # communication-volume bookkeeping
        self.device = None

    def _init_node(self, model, rank, num_nodes):
        self.rank = rank
        self.num_nodes = num_nodes
        if torch.cuda.is_available():
            # One GPU per rank.
            self.device = torch.device(f"cuda:{rank}")
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")
        self.model = model.to(self.device)

    def zero_grad(self):
        self.optim.zero_grad(set_to_none=True)

    @abstractmethod
    def step(self):
        pass


class CommunicationModule(ABC):
    # Pluggable communication policy: decides when and how nodes sync.
    def _init_node(self, model, rank, num_nodes):
        pass

    @abstractmethod
    def communicate(self, model, rank, num_nodes, local_step):
        pass


class LateSyncCommunicator(CommunicationModule):
    # Trains fully locally, then averages parameters across nodes exactly
    # once, at sync_step.
    def __init__(self, sync_step=80):
        self.sync_step = sync_step
        self.done = False

    def communicate(self, model, rank, num_nodes, local_step):
        if num_nodes <= 1:
            return  # nothing to sync in a single-node run
        if local_step < self.sync_step:
            return  # keep training locally until the sync step
        if self.done:
            return  # the sync happens only once
        with torch.no_grad():
            for p in model.parameters():
                if not p.is_cuda:
                    # Only CUDA parameters are averaged; CPU parameters
                    # are left unsynchronized.
                    continue
                all_reduce(p.data)
                p.data.div_(num_nodes)  # sum -> mean
        self.done = True

    def _init_node(self, model, rank, num_nodes):
        pass


class LocalFirstStrategy(Strategy):
    # Local AdamW updates with a single late parameter average at step 80.
    def __init__(self):
        super().__init__()
        self.comm = LateSyncCommunicator(sync_step=80)

    def _init_node(self, model, rank, num_nodes):
        super()._init_node(model, rank, num_nodes)
        self.optim = torch.optim.AdamW(
            self.model.parameters(),
            lr=3e-4,
            betas=(0.9, 0.95),
            weight_decay=0.01,
        )
        self.comm._init_node(self.model, rank, num_nodes)

    def step(self):
        # Assumes gradients have already been computed by the caller.
        nn_utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optim.step()
        self.comm.communicate(
            self.model, self.rank, self.num_nodes, self.local_step
        )
        self.local_step += 1


STRATEGY = LocalFirstStrategy()
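
A minimal driver sketch, not part of the submission: it assumes a harness that calls _init_node once and then zero_grad / backward / step each iteration. The toy linear model, random batch, and single-process setup below are hypothetical stand-ins for whatever the evaluation harness actually provides.

# Hypothetical usage sketch; the model, data, and launch setup are assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F


def run_local_demo(rank=0, num_nodes=1):
    model = nn.Linear(16, 1)  # stand-in for the real model
    STRATEGY._init_node(model, rank, num_nodes)

    for _ in range(STRATEGY.max_steps):
        x = torch.randn(8, 16, device=STRATEGY.device)
        y = torch.randn(8, 1, device=STRATEGY.device)
        STRATEGY.zero_grad()
        loss = F.mse_loss(STRATEGY.model(x), y)
        loss.backward()
        STRATEGY.step()  # clip, AdamW update, late all-reduce after step 80


if __name__ == "__main__":
    # A multi-node run would additionally need dist.init_process_group()
    # before _init_node; with num_nodes == 1 the communicator is a no-op.
    run_local_demo()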