@strnan
Last active January 2, 2026 18:48
Distributed training strategy submission
import math
import os
import torch
import torch.nn.utils as nn_utils
import torch.distributed as dist
import torch.fft
from einops import rearrange
import datetime
from copy import deepcopy
from dataclasses import dataclass
from torch.optim.lr_scheduler import LambdaLR
from typing import (
    List,
    Type,
    Union,
    Optional,
    Dict,
    Any,
    TypeAlias,
    Callable,
    Iterable,
    Tuple,
)
from abc import ABC, abstractmethod
from exogym.aux.utils import LogModule
ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[dict[str, Any]]]


def mps_compatible(func):
    """Wrap a torch.distributed collective so that MPS tensors are staged through
    CPU buffers (the gloo backend cannot operate on MPS memory) and the results
    are copied back onto the original device."""

    def all_gather_wrapper(tensor_list, tensor, *args, **kwargs):
        is_tensor_mps = hasattr(tensor, "device") and tensor.device.type == "mps"
        is_list_mps = any(
            hasattr(t, "device") and t.device.type == "mps" for t in tensor_list
        )
        if is_tensor_mps or is_list_mps:
            cpu_tensor = tensor.data.to("cpu") if is_tensor_mps else tensor
            cpu_tensor_list = [
                t.data.to("cpu") if hasattr(t, "device") and t.device.type == "mps" else t
                for t in tensor_list
            ]
            result = func(cpu_tensor_list, cpu_tensor, *args, **kwargs)
            if is_tensor_mps:
                tensor.data.copy_(cpu_tensor.to("mps"))
            for i, t in enumerate(tensor_list):
                if hasattr(t, "device") and t.device.type == "mps":
                    t.data.copy_(cpu_tensor_list[i].to("mps"))
            return result
        return func(tensor_list, tensor, *args, **kwargs)

    def standard_wrapper(tensor, *args, **kwargs):
        if hasattr(tensor, "device") and tensor.device.type == "mps":
            cpu_tensor = tensor.data.to("cpu")
            result = func(cpu_tensor, *args, **kwargs)
            tensor.data.copy_(cpu_tensor.to("mps"))
            return result
        return func(tensor, *args, **kwargs)

    return all_gather_wrapper if func.__name__ == "all_gather" else standard_wrapper


@mps_compatible
def broadcast(tensor, src=0):
    return dist.broadcast(tensor, src=src)


@mps_compatible
def all_reduce(tensor, op=dist.ReduceOp.SUM):
    return dist.all_reduce(tensor, op=op)


@mps_compatible
def all_gather(tensor_list, tensor, group=None, async_op=False):
    return dist.all_gather(tensor_list, tensor, group=group, async_op=async_op)
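

# Why the wrappers above exist: gloo collectives cannot operate on MPS tensors, so
# each call is staged through a CPU copy and the result is copied back. Illustrative
# single-process sketch (not part of the original gist; assumes gloo is initialised
# locally with world_size=1):
#
#     dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
#                             rank=0, world_size=1)
#     t = torch.ones(4, device="mps")
#     all_reduce(t)                    # gloo only ever sees the CPU staging copy
#     assert t.device.type == "mps"    # result copied back onto the MPS tensor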


@dataclass
class OptimSpec:
    cls: Type[torch.optim.Optimizer]
    kwargs: Dict[str, Any]

    def build(self, model):
        return self.cls(model.parameters(), **(self.kwargs or {}))


def ensure_optim_spec(
    optim: Union[str, OptimSpec, None], default: Optional[OptimSpec] = None, **kwargs
) -> OptimSpec:
    if optim is None:
        return default or OptimSpec(torch.optim.AdamW, kwargs)
    if isinstance(optim, OptimSpec):
        return optim
    raise TypeError(f"Unsupported optimizer specification: {optim!r}")
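
# Usage sketch (illustrative, not part of the original gist): an OptimSpec is a
# deferred optimizer constructor, so each node can build its own instance from
# the same recipe.
#
#     spec = OptimSpec(torch.optim.AdamW, {"lr": 1e-3})
#     inner_optim = spec.build(model)   # == AdamW(model.parameters(), lr=1e-3)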


class Strategy(ABC, LogModule):
    def __init__(self, lr_scheduler=None, lr_scheduler_kwargs=None, **kwargs):
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or {}
        self.kwargs = kwargs
        self.scheduler = None
        self.lr_callbacks = []
        self.max_steps = 1

    def _init_node(self, model, rank, num_nodes):
        self.model = model
        self.rank = rank
        self.num_nodes = num_nodes
        self.local_step = 0

    @abstractmethod
    def step(self):
        self.local_step += 1

    def zero_grad(self):
        self.optim.zero_grad()

    def _setup_scheduler(self):
        def lr_lambda(step):
            warmup = self.lr_scheduler_kwargs.get("warmup_steps", 1)
            max_steps = self.lr_scheduler_kwargs.get("max_steps", self.max_steps)
            if step < warmup:
                return step / max(1, warmup)
            progress = (step - warmup) / max(1, max_steps - warmup)
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        if self.lr_scheduler == "lambda_cosine":
            self.scheduler = LambdaLR(self.optim, lr_lambda)

    def __config__(self):
        return {"strategy": self.__class__.__name__}


class CommunicationModule(ABC):
    @abstractmethod
    def communicate(self, model, rank, num_nodes, local_step):
        pass

    @abstractmethod
    def _init_node(self, model, rank, num_nodes):
        pass


class CommunicateOptimizeStrategy(Strategy):
    def __init__(self, communication_modules, optim_spec=None, max_norm=None, **kwargs):
        super().__init__(**kwargs)
        self.communication_modules = communication_modules
        self.optim_spec = optim_spec
        self.max_norm = max_norm
        for m in self.communication_modules:
            m.strategy = self

    def _init_node(self, model, rank, num_nodes):
        super()._init_node(model, rank, num_nodes)
        self.optim = self.optim_spec.build(model)
        self._setup_scheduler()
        for m in self.communication_modules:
            m._init_node(model, rank, num_nodes)

    def step(self):
        # Clip, take the inner optimizer step, then let each communication module
        # decide whether this local step triggers a synchronization.
        if self.max_norm:
            nn_utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        self.optim.step()
        for m in self.communication_modules:
            m.communicate(self.model, self.rank, self.num_nodes, self.local_step)
        if self.scheduler:
            self.scheduler.step()
        self.local_step += 1


class DiLoCoCommunicator(CommunicationModule):
    """DiLoCo-style communication: every H local steps, the drift between the CPU
    master copy and the local replica is fed to an outer optimizer as a
    pseudo-gradient, and the updated master weights are copied back."""

    def __init__(self, H=25, outer_optim_spec=None):
        self.H = H
        self.outer_optim_spec = outer_optim_spec

    def _init_node(self, model, rank, num_nodes):
        self.pg = dist.new_group(backend="gloo", timeout=datetime.timedelta(60))
        self.master_model = deepcopy(model).to("cpu")
        for p in self.master_model.parameters():
            p.requires_grad = True
        self.outer_optim = self.outer_optim_spec.cls(
            self.master_model.parameters(),
            process_group=self.pg,
            **self.outer_optim_spec.kwargs,
        )

    def communicate(self, model, rank, num_nodes, local_step):
        if num_nodes > 1 and local_step > 0 and local_step % self.H == 0:
            self.outer_optim.zero_grad()
            for n, p in self.master_model.named_parameters():
                # Outer pseudo-gradient: how far the local replica has drifted
                # from the master weights since the last synchronization.
                p.grad = p.data - model.state_dict()[n].data.to("cpu")
            self.outer_optim.step()
            for n, p in model.named_parameters():
                # Pull the updated master weights back into the local replica.
                p.data.copy_(self.master_model.state_dict()[n].to(p.device))


class DiLoCoStrategy(CommunicateOptimizeStrategy):
    def __init__(self, optim_spec, outer_optim_spec, H=25, **kwargs):
        self.comm = DiLoCoCommunicator(H=H, outer_optim_spec=outer_optim_spec)
        super().__init__(
            communication_modules=[self.comm],
            optim_spec=optim_spec,
            **kwargs,
        )


class SparseLoCo(torch.optim.SGD):
    def __init__(
        self,
        params,
        lr,
        momentum=0.9,
        weight_decay=0.05,
        top_k=64,
        chunk_size=64,
        use_dct=True,
        use_quantization=True,
        quantization_bins=4,
        quantization_range=6,
        process_group=None,
        **kwargs,
    ):
        super().__init__(
            params,
            lr=lr,
            momentum=momentum,
            weight_decay=0.0,  # the base optimizer gets no coupled weight decay
            **kwargs,
        )
        self.decoupled_weight_decay = weight_decay
        self.process_group = process_group
        # Note: the decoupled weight decay and the compression settings (top_k,
        # chunk_size, use_dct, quantization_*) are accepted but not applied in
        # this listing; step() below defers to plain momentum SGD.

    @torch.no_grad()
    def step(self, closure=None):
        return super().step(closure)


STRATEGY = DiLoCoStrategy(
    optim_spec=OptimSpec(
        torch.optim.AdamW,
        {"lr": 0.001},
    ),
    outer_optim_spec=OptimSpec(
        SparseLoCo,
        {
            "lr": 0.8,
            "momentum": 0.9,
            "weight_decay": 0.05,
            "top_k": 64,
            "chunk_size": 64,
            "use_dct": True,
            "use_quantization": True,
            "quantization_bins": 4,
            "quantization_range": 6,
        },
    ),
    lr_scheduler="lambda_cosine",
    lr_scheduler_kwargs={
        "warmup_steps": 800,
        "max_steps": 100,
    },
    max_norm=1.5,
    H=25,
)
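
# Usage sketch (illustrative only; in practice the exogym trainer drives the
# strategy, and `model`, `loader`, and `loss_fn` below are hypothetical
# placeholders):
#
#     STRATEGY._init_node(model, rank=rank, num_nodes=world_size)
#     for batch in loader:
#         STRATEGY.zero_grad()
#         loss = loss_fn(model(batch))
#         loss.backward()
#         STRATEGY.step()   # inner AdamW step; every H=25 local steps the
#                           # DiLoCo outer SparseLoCo step runs on the master copy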