
@catid
Created February 9, 2026 19:58
Implementation of AdaGO: about the same speed as AdamW, but much better results for small local training runs.

from __future__ import annotations
from collections import defaultdict
from typing import Iterable
import torch
from torch import Tensor
from torch.optim import Optimizer
from torch.optim._functional import adamw as functional_adamw
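
# Per-step coefficients (a, b, c) for the quintic Newton-Schulz iteration
# X <- a*X + b*(X X^T) X + c*(X X^T)^2 X used for orthogonalization below;
# requesting more steps than rows repeats the final row (see _ns_coeffs).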
_NS_COEFFS: list[tuple[float, float, float]] = [
    (4.0848, -6.8946, 2.9270),
    (3.9505, -6.3029, 2.6377),
    (3.7418, -5.5913, 2.3037),
    (2.8769, -3.1427, 1.2046),
    (2.8366, -3.0525, 1.2012),
]


def _normalize_vector(x: Tensor, eps: float) -> Tensor:
    if x.ndim == 0:
        return x / x.abs().clamp_min(eps)
    return x / torch.linalg.vector_norm(x).clamp_min(eps)


def _as_matrix(x: Tensor) -> tuple[Tensor, torch.Size | None]:
    if x.ndim < 2:
        return x, None
    return x.reshape(x.shape[0], -1), x.shape


def _orthogonalize_svd(x: Tensor, eps: float) -> Tensor:
    matrix, original_shape = _as_matrix(x)
    if original_shape is None:
        return _normalize_vector(matrix, eps)
    work = matrix.float()
    u, _, vh = torch.linalg.svd(work, full_matrices=False)
    ortho = (u @ vh).to(dtype=x.dtype)
    return ortho.reshape(original_shape)


def _ns_coeffs(num_steps: int) -> list[tuple[float, float, float]]:
    if num_steps < 1:
        raise ValueError("ns_steps must be >= 1")
    if num_steps <= len(_NS_COEFFS):
        return _NS_COEFFS[:num_steps]
    return _NS_COEFFS + [_NS_COEFFS[-1]] * (num_steps - len(_NS_COEFFS))


def _orthogonalize_newton_schulz_heavyball(
    x: Tensor,
    eps: float,
    ns_steps: int,
    compute_dtype: torch.dtype,
) -> Tensor:
    matrix, original_shape = _as_matrix(x)
    if original_shape is None:
        return _normalize_vector(matrix, eps)
    work = matrix.to(dtype=compute_dtype)
    transposed = False
    if work.shape[0] > work.shape[1]:
        work = work.t()
        transposed = True
    norm = torch.linalg.vector_norm(work.float()).to(dtype=work.dtype).clamp_min(eps)
    work = work / norm
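    # Quintic Newton-Schulz iteration: with S = X X^T, each step applies
    # X <- (a*I + b*S + c*S^2) X, pushing the singular values of X toward 1.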
    for a, b, c in _ns_coeffs(ns_steps):
        s = work @ work.t()
        y = c * s
        y.diagonal().add_(b)
        y = y @ s
        y.diagonal().add_(a)
        work = y @ work
    if transposed:
        work = work.t()
    return work.to(dtype=x.dtype).reshape(original_shape)


def _orthogonalize_hybrid(
    x: Tensor,
    eps: float,
    svd_max_dim: int,
    ns_steps: int,
    compute_dtype: torch.dtype,
) -> Tensor:
    matrix, original_shape = _as_matrix(x)
    if original_shape is None:
        return _normalize_vector(matrix, eps)
    m, n = matrix.shape
    # Empirically best dispatch from this project:
    # - rank-1 matrices: SVD faster
    # - tiny matrices: SVD competitive / often faster
    # - otherwise: NS-heavyball much faster with strong quality
    if min(m, n) == 1 or max(m, n) <= svd_max_dim:
        return _orthogonalize_svd(x, eps)
    return _orthogonalize_newton_schulz_heavyball(
        x,
        eps=eps,
        ns_steps=ns_steps,
        compute_dtype=compute_dtype,
    )


def _stack_as_matrices(xs: list[Tensor], compute_dtype: torch.dtype) -> tuple[Tensor, int, int, torch.Size]:
    if not xs:
        raise ValueError("Expected at least one tensor.")
    first = xs[0]
    if first.ndim < 2:
        raise ValueError("Expected matrix-like tensors.")
    m = first.shape[0]
    n = first.numel() // m
    mats = torch.stack([x.reshape(m, n).to(dtype=compute_dtype) for x in xs], dim=0)
    return mats, m, n, first.shape


def _orthogonalize_svd_batched(xs: list[Tensor], eps: float) -> list[Tensor]:
    if len(xs) == 1:
        return [_orthogonalize_svd(xs[0], eps)]
    mats, _, _, original_shape = _stack_as_matrices(xs, compute_dtype=torch.float32)
    u, _, vh = torch.linalg.svd(mats, full_matrices=False)
    ortho = torch.bmm(u, vh).to(dtype=xs[0].dtype)
    return [ortho[i].reshape(original_shape) for i in range(ortho.shape[0])]


def _orthogonalize_newton_schulz_heavyball_batched(
    xs: list[Tensor],
    eps: float,
    ns_steps: int,
    compute_dtype: torch.dtype,
) -> list[Tensor]:
    if len(xs) == 1:
        return [
            _orthogonalize_newton_schulz_heavyball(
                xs[0],
                eps=eps,
                ns_steps=ns_steps,
                compute_dtype=compute_dtype,
            ),
        ]
    work, m, n, original_shape = _stack_as_matrices(xs, compute_dtype=compute_dtype)
    transposed = False
    if m > n:
        work = work.transpose(1, 2)
        transposed = True
    norms = torch.linalg.vector_norm(work.float(), dim=(1, 2), keepdim=True).to(dtype=work.dtype).clamp_min(eps)
    work = work / norms
    for a, b, c in _ns_coeffs(ns_steps):
        s = torch.bmm(work, work.transpose(1, 2))
        y = c * s
        y.diagonal(dim1=-2, dim2=-1).add_(b)
        y = torch.bmm(y, s)
        y.diagonal(dim1=-2, dim2=-1).add_(a)
        work = torch.bmm(y, work)
    if transposed:
        work = work.transpose(1, 2)
    work = work.to(dtype=xs[0].dtype)
    return [work[i].reshape(original_shape) for i in range(work.shape[0])]


def _orthogonalize_hybrid_batched(
    xs: list[Tensor],
    eps: float,
    svd_max_dim: int,
    ns_steps: int,
    compute_dtype: torch.dtype,
) -> list[Tensor]:
    if not xs:
        return []
    if xs[0].ndim < 2:
        return [_normalize_vector(x, eps) for x in xs]
    m = xs[0].shape[0]
    n = xs[0].numel() // m
    if min(m, n) == 1 or max(m, n) <= svd_max_dim:
        return _orthogonalize_svd_batched(xs, eps)
    return _orthogonalize_newton_schulz_heavyball_batched(
        xs,
        eps=eps,
        ns_steps=ns_steps,
        compute_dtype=compute_dtype,
    )
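

# Parameters that share a shape are stacked and handled with a single batched
# SVD or bmm-based Newton-Schulz pass, amortizing kernel launches across
# same-shape weight matrices (see the shape-group logic in AdaGO.step()).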
class AdaGO(Optimizer):
    """
    Hybrid AdaGO optimizer.

    Update:
        M_t = mu * M_{t-1} + (1 - mu) * G_t
        v_t^2 = v_{t-1}^2 + min(||G_t||, gamma)^2
        alpha_t = max(eps_floor, lr * min(||G_t||, gamma) / v_t)
        theta_t = theta_{t-1} - alpha_t * Orth(M_t)

    For matrix-like params (ndim >= 2), Orth(.) uses a hybrid dispatcher:
        - SVD for tiny/rank-1 matrices
        - NS-heavyball elsewhere

    For scalar/vector params (ndim < 2), this optimizer optionally applies
    AdamW updates (paper-style hybrid setup with decoupled weight decay)
    instead of orthogonalized updates, via PyTorch's optimized functional path.
    """
    def __init__(
        self,
        params: Iterable[Tensor],
        lr: float = 1e-2,
        momentum: float = 0.95,
        gamma: float = 1.0,
        eps_floor: float = 1e-6,
        v0: float = 1e-3,
        weight_decay: float = 0.0,
        orth_eps: float = 1e-7,
        svd_max_dim: int = 8,
        ns_steps: int = 5,
        orth_compute_dtype: str = "fp32",
        norm_compute_dtype: str = "fp32",
        cache_shape_groups: bool = True,
        fused_adamw_non_matrix: bool = False,
        use_adam_for_non_matrix: bool = True,
        assume_dense_grads: bool = False,
        orth_update_interval: int = 1,
        adamw_numel_threshold: int = 0,
        adam_beta1: float = 0.9,
        adam_beta2: float = 0.999,
        adam_eps: float = 1e-8,
    ) -> None:
        if lr <= 0:
            raise ValueError("lr must be > 0")
        if not (0 <= momentum < 1):
            raise ValueError("momentum must be in [0, 1)")
        if gamma <= 0:
            raise ValueError("gamma must be > 0")
        if eps_floor <= 0:
            raise ValueError("eps_floor must be > 0")
        if v0 <= 0:
            raise ValueError("v0 must be > 0")
        if weight_decay < 0:
            raise ValueError("weight_decay must be >= 0")
        if orth_eps <= 0:
            raise ValueError("orth_eps must be > 0")
        if svd_max_dim < 1:
            raise ValueError("svd_max_dim must be >= 1")
        if ns_steps < 1:
            raise ValueError("ns_steps must be >= 1")
        if orth_compute_dtype not in ("fp32", "bf16"):
            raise ValueError("orth_compute_dtype must be one of: fp32, bf16")
        if norm_compute_dtype not in ("fp32", "bf16"):
            raise ValueError("norm_compute_dtype must be one of: fp32, bf16")
        if not (0 <= adam_beta1 < 1):
            raise ValueError("adam_beta1 must be in [0, 1)")
        if not (0 <= adam_beta2 < 1):
            raise ValueError("adam_beta2 must be in [0, 1)")
        if adam_eps <= 0:
            raise ValueError("adam_eps must be > 0")
        if orth_update_interval < 1:
            raise ValueError("orth_update_interval must be >= 1")
        if adamw_numel_threshold < 0:
            raise ValueError("adamw_numel_threshold must be >= 0")
        defaults = dict(
            lr=lr,
            momentum=momentum,
            gamma=gamma,
            eps_floor=eps_floor,
            v0=v0,
            weight_decay=weight_decay,
            orth_eps=orth_eps,
            svd_max_dim=svd_max_dim,
            ns_steps=ns_steps,
            orth_compute_dtype=orth_compute_dtype,
            norm_compute_dtype=norm_compute_dtype,
            cache_shape_groups=cache_shape_groups,
            fused_adamw_non_matrix=fused_adamw_non_matrix,
            use_adam_for_non_matrix=use_adam_for_non_matrix,
            assume_dense_grads=assume_dense_grads,
            orth_update_interval=orth_update_interval,
            adamw_numel_threshold=adamw_numel_threshold,
            adam_beta1=adam_beta1,
            adam_beta2=adam_beta2,
            adam_eps=adam_eps,
        )
        super().__init__(params, defaults)

    def _build_group_static_cache(
        self,
        group: dict,
        *,
        v0: float,
        use_adam_for_non_matrix: bool,
        orth_update_interval: int,
        adamw_numel_threshold: int,
    ) -> dict:
        matrix_params: list[Tensor] = []
        momentum_buffers: list[Tensor] = []
        v_sqs: list[Tensor] = []
        orth_directions: list[Tensor | None] = []
        vector_params: list[Tensor] = []
        adam_exp_avgs: list[Tensor] = []
        adam_exp_avg_sqs: list[Tensor] = []
        adam_max_exp_avg_sqs: list[Tensor] = []
        adam_state_steps: list[Tensor] = []
        shape_to_indices: dict[tuple[int, ...], list[int]] = defaultdict(list)
        for p in group["params"]:
            state = self.state[p]
            use_adamw_param = use_adam_for_non_matrix and (
                p.ndim < 2 or (adamw_numel_threshold > 0 and p.ndim >= 2 and p.numel() <= adamw_numel_threshold)
            )
            if use_adamw_param:
                if len(state) == 0 or "adam_exp_avg" not in state:
                    state["adam_exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["adam_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["adam_step"] = torch.tensor(0.0, dtype=torch.float32)
                vector_params.append(p)
                adam_exp_avgs.append(state["adam_exp_avg"])
                adam_exp_avg_sqs.append(state["adam_exp_avg_sq"])
                adam_state_steps.append(state["adam_step"])
                continue
            if len(state) == 0 or "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                state["v_sq"] = torch.full((), v0 * v0, dtype=torch.float32, device=p.device)
            if orth_update_interval > 1:
                if "orth_direction" not in state:
                    state["orth_direction"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                orth_direction = state["orth_direction"]
            else:
                orth_direction = None
            matrix_index = len(matrix_params)
            matrix_params.append(p)
            momentum_buffers.append(state["momentum_buffer"])
            v_sqs.append(state["v_sq"])
            orth_directions.append(orth_direction)
            shape_to_indices[tuple(p.shape)].append(matrix_index)
        shape_groups = list(shape_to_indices.values())
        shape_group_items = [
            (
                [matrix_params[i] for i in indices],
                [momentum_buffers[i] for i in indices],
                [v_sqs[i] for i in indices],
                [orth_directions[i] for i in indices],
            )
            for indices in shape_groups
        ]
        return {
            "group_param_count": len(group["params"]),
            "use_adam_for_non_matrix": use_adam_for_non_matrix,
            "orth_update_interval": orth_update_interval,
            "adamw_numel_threshold": adamw_numel_threshold,
            "matrix_params": matrix_params,
            "momentum_buffers": momentum_buffers,
            "v_sqs": v_sqs,
            "orth_directions": orth_directions,
            "vector_params": vector_params,
            "adam_exp_avgs": adam_exp_avgs,
            "adam_exp_avg_sqs": adam_exp_avg_sqs,
            "adam_max_exp_avg_sqs": adam_max_exp_avg_sqs,
            "adam_state_steps": adam_state_steps,
            "shape_groups": shape_groups,
            "shape_group_items": shape_group_items,
        }

    @staticmethod
    def _cache_valid(
        cache: dict | None,
        *,
        group: dict,
        use_adam_for_non_matrix: bool,
        orth_update_interval: int,
        adamw_numel_threshold: int,
    ) -> bool:
        if not isinstance(cache, dict):
            return False
        if cache.get("group_param_count") != len(group["params"]):
            return False
        if cache.get("use_adam_for_non_matrix") != use_adam_for_non_matrix:
            return False
        if cache.get("orth_update_interval") != orth_update_interval:
            return False
        if cache.get("adamw_numel_threshold") != adamw_numel_threshold:
            return False
        return True

    def _collect_dynamic_group_state(
        self,
        group: dict,
        *,
        v0: float,
        use_adam_for_non_matrix: bool,
        orth_update_interval: int,
        adamw_numel_threshold: int,
    ) -> tuple[
        list[Tensor],
        list[Tensor],
        list[Tensor],
        list[Tensor],
        list[Tensor | None],
        list[Tensor],
        list[Tensor],
        list[Tensor],
        list[Tensor],
        list[Tensor],
        list[Tensor],
        list[list[int]],
    ]:
        matrix_params: list[Tensor] = []
        matrix_grads: list[Tensor] = []
        momentum_buffers: list[Tensor] = []
        v_sqs: list[Tensor] = []
        orth_directions: list[Tensor | None] = []
        vector_params: list[Tensor] = []
        vector_grads: list[Tensor] = []
        adam_exp_avgs: list[Tensor] = []
        adam_exp_avg_sqs: list[Tensor] = []
        adam_max_exp_avg_sqs: list[Tensor] = []
        adam_state_steps: list[Tensor] = []
        shape_to_indices: dict[tuple[int, ...], list[int]] = defaultdict(list)
        for p in group["params"]:
            grad = p.grad
            if grad is None:
                continue
            if grad.is_sparse:
                raise RuntimeError("AdaGO does not support sparse gradients.")
            state = self.state[p]
            use_adamw_param = use_adam_for_non_matrix and (
                p.ndim < 2 or (adamw_numel_threshold > 0 and p.ndim >= 2 and p.numel() <= adamw_numel_threshold)
            )
            if use_adamw_param:
                if len(state) == 0 or "adam_exp_avg" not in state:
                    state["adam_exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["adam_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["adam_step"] = torch.tensor(0.0, dtype=torch.float32)
                vector_params.append(p)
                vector_grads.append(grad)
                adam_exp_avgs.append(state["adam_exp_avg"])
                adam_exp_avg_sqs.append(state["adam_exp_avg_sq"])
                adam_state_steps.append(state["adam_step"])
            else:
                if len(state) == 0 or "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["v_sq"] = torch.full((), v0 * v0, dtype=torch.float32, device=p.device)
                if orth_update_interval > 1:
                    if "orth_direction" not in state:
                        state["orth_direction"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    orth_direction = state["orth_direction"]
                else:
                    orth_direction = None
                matrix_idx = len(matrix_params)
                matrix_params.append(p)
                matrix_grads.append(grad)
                momentum_buffers.append(state["momentum_buffer"])
                v_sqs.append(state["v_sq"])
                orth_directions.append(orth_direction)
                shape_to_indices[tuple(p.shape)].append(matrix_idx)
        return (
            matrix_params,
            matrix_grads,
            momentum_buffers,
            v_sqs,
            orth_directions,
            vector_params,
            vector_grads,
            adam_exp_avgs,
            adam_exp_avg_sqs,
            adam_max_exp_avg_sqs,
            adam_state_steps,
            list(shape_to_indices.values()),
        )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = float(group["lr"])
            momentum = float(group["momentum"])
            gamma = float(group["gamma"])
            eps_floor = float(group["eps_floor"])
            v0 = float(group["v0"])
            weight_decay = float(group["weight_decay"])
            orth_eps = float(group["orth_eps"])
            svd_max_dim = int(group["svd_max_dim"])
            ns_steps = int(group["ns_steps"])
            orth_compute_dtype = str(group["orth_compute_dtype"])
            norm_compute_dtype = str(group["norm_compute_dtype"])
            cache_shape_groups = bool(group["cache_shape_groups"])
            fused_adamw_non_matrix = bool(group["fused_adamw_non_matrix"])
            use_adam_for_non_matrix = bool(group["use_adam_for_non_matrix"])
            assume_dense_grads = bool(group.get("assume_dense_grads", False))
            orth_update_interval = int(group.get("orth_update_interval", 1))
            adamw_numel_threshold = int(group.get("adamw_numel_threshold", 0))
            adam_beta1 = float(group["adam_beta1"])
            adam_beta2 = float(group["adam_beta2"])
            adam_eps = float(group["adam_eps"])
            orth_dtype = torch.float32 if orth_compute_dtype == "fp32" else torch.bfloat16
            norm_dtype = torch.float32 if norm_compute_dtype == "fp32" else torch.bfloat16
            group_step = int(group.get("_adago_step", 0)) + 1
            group["_adago_step"] = group_step
            recompute_orth = ((group_step - 1) % orth_update_interval) == 0
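            # With orth_update_interval > 1, recompute_orth is True only every
            # interval-th step; in between, the cached per-parameter
            # "orth_direction" is reused, trading accuracy for speed.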
            cache = group.get("_adago_static_cache")
            if not self._cache_valid(
                cache,
                group=group,
                use_adam_for_non_matrix=use_adam_for_non_matrix,
                orth_update_interval=orth_update_interval,
                adamw_numel_threshold=adamw_numel_threshold,
            ):
                cache = self._build_group_static_cache(
                    group,
                    v0=v0,
                    use_adam_for_non_matrix=use_adam_for_non_matrix,
                    orth_update_interval=orth_update_interval,
                    adamw_numel_threshold=adamw_numel_threshold,
                )
                group["_adago_static_cache"] = cache
            matrix_params: list[Tensor] = cache["matrix_params"]
            momentum_buffers: list[Tensor] = cache["momentum_buffers"]
            v_sqs: list[Tensor] = cache["v_sqs"]
            orth_directions: list[Tensor | None] = cache["orth_directions"]
            vector_params: list[Tensor] = cache["vector_params"]
            adam_exp_avgs: list[Tensor] = cache["adam_exp_avgs"]
            adam_exp_avg_sqs: list[Tensor] = cache["adam_exp_avg_sqs"]
            adam_max_exp_avg_sqs: list[Tensor] = cache["adam_max_exp_avg_sqs"]
            adam_state_steps: list[Tensor] = cache["adam_state_steps"]
            cached_shape_groups: list[list[int]] = cache["shape_groups"]
            cached_shape_group_items = cache["shape_group_items"]
            matrix_grads: list[Tensor]
            vector_grads: list[Tensor]
            use_dynamic_state_collection = False
            if assume_dense_grads:
                matrix_grads = [p.grad for p in matrix_params]
                vector_grads = [p.grad for p in vector_params]
                if any(g is None for g in matrix_grads) or any(g is None for g in vector_grads):
                    use_dynamic_state_collection = True
                else:
                    if any(g.is_sparse for g in matrix_grads) or any(g.is_sparse for g in vector_grads):
                        raise RuntimeError("AdaGO does not support sparse gradients.")
            else:
                all_matrix_have_grads = True
                matrix_grads = []
                for p in matrix_params:
                    grad = p.grad
                    if grad is None:
                        all_matrix_have_grads = False
                        break
                    if grad.is_sparse:
                        raise RuntimeError("AdaGO does not support sparse gradients.")
                    matrix_grads.append(grad)
                all_vector_have_grads = True
                vector_grads = []
                for p in vector_params:
                    grad = p.grad
                    if grad is None:
                        all_vector_have_grads = False
                        break
                    if grad.is_sparse:
                        raise RuntimeError("AdaGO does not support sparse gradients.")
                    vector_grads.append(grad)
                if not all_matrix_have_grads or not all_vector_have_grads:
                    use_dynamic_state_collection = True
            if use_dynamic_state_collection:
                (
                    matrix_params,
                    matrix_grads,
                    momentum_buffers,
                    v_sqs,
                    orth_directions,
                    vector_params,
                    vector_grads,
                    adam_exp_avgs,
                    adam_exp_avg_sqs,
                    adam_max_exp_avg_sqs,
                    adam_state_steps,
                    shape_groups_runtime,
                ) = self._collect_dynamic_group_state(
                    group,
                    v0=v0,
                    use_adam_for_non_matrix=use_adam_for_non_matrix,
                    orth_update_interval=orth_update_interval,
                    adamw_numel_threshold=adamw_numel_threshold,
                )
                use_cached_shape_group_items = False
            else:
                shape_groups_runtime = cached_shape_groups
                use_cached_shape_group_items = cache_shape_groups and assume_dense_grads
            if not matrix_params and not vector_params:
                continue
            if weight_decay != 0 and matrix_params:
                torch._foreach_mul_(matrix_params, 1.0 - lr * weight_decay)
            if matrix_params:
                torch._foreach_mul_(momentum_buffers, momentum)
                torch._foreach_add_(momentum_buffers, matrix_grads, alpha=1.0 - momentum)
            if use_cached_shape_group_items:
                shape_group_items = cached_shape_group_items
                shape_groups = None
            elif cache_shape_groups:
                shape_groups = shape_groups_runtime
                shape_group_items = None
            else:
                shape_to_indices: dict[tuple[int, ...], list[int]] = defaultdict(list)
                for i, p in enumerate(matrix_params):
                    shape_to_indices[tuple(p.shape)].append(i)
                shape_groups = list(shape_to_indices.values())
                shape_group_items = None
            if shape_group_items is not None:
                iterable_groups = (
                    (
                        params_group,
                        [p.grad for p in params_group],
                        momentum_group,
                        v_group,
                        orth_direction_group,
                    )
                    for (params_group, momentum_group, v_group, orth_direction_group) in shape_group_items
                )
            else:
                iterable_groups = (
                    (
                        [matrix_params[i] for i in indices],
                        [matrix_grads[i] for i in indices],
                        [momentum_buffers[i] for i in indices],
                        [v_sqs[i] for i in indices],
                        [orth_directions[i] for i in indices],
                    )
                    for indices in shape_groups
                )
            for params_group, grads_group, momentum_group, v_group, orth_direction_group in iterable_groups:
                if not grads_group or grads_group[0] is None:
                    continue
                first_grad = grads_group[0]
                if first_grad.ndim < 2:
                    g_norms = torch.stack(
                        [torch.linalg.vector_norm(g.to(dtype=norm_dtype)) for g in grads_group],
                        dim=0,
                    )
                else:
                    m = first_grad.shape[0]
                    n = first_grad.numel() // m
                    g_work = torch.stack([g.reshape(m, n).to(dtype=norm_dtype) for g in grads_group], dim=0)
                    g_norms = torch.linalg.vector_norm(g_work, dim=(1, 2))
                g_norms = g_norms.clamp_max(gamma)
                v_vals = torch.stack(v_group, dim=0)
                v_vals.add_(g_norms.square())
                scales = torch.clamp_min(lr * g_norms / v_vals.sqrt(), eps_floor)
                for v_sq, new_v in zip(v_group, v_vals):
                    v_sq.copy_(new_v)
                if recompute_orth:
                    directions = _orthogonalize_hybrid_batched(
                        momentum_group,
                        eps=orth_eps,
                        svd_max_dim=svd_max_dim,
                        ns_steps=ns_steps,
                        compute_dtype=orth_dtype,
                    )
                    if orth_update_interval > 1:
                        for orth_cache, direction in zip(orth_direction_group, directions):
                            if orth_cache is not None:
                                orth_cache.copy_(direction)
                    for direction, scale in zip(directions, scales):
                        direction.mul_(scale.to(dtype=direction.dtype))
                    torch._foreach_add_(params_group, directions, alpha=-1.0)
                else:
                    if any(direction is None for direction in orth_direction_group):
                        directions = _orthogonalize_hybrid_batched(
                            momentum_group,
                            eps=orth_eps,
                            svd_max_dim=svd_max_dim,
                            ns_steps=ns_steps,
                            compute_dtype=orth_dtype,
                        )
                        if orth_update_interval > 1:
                            for orth_cache, direction in zip(orth_direction_group, directions):
                                if orth_cache is not None:
                                    orth_cache.copy_(direction)
                    else:
                        directions = orth_direction_group
                    scaled_directions = [
                        direction * scale.to(dtype=direction.dtype)
                        for direction, scale in zip(directions, scales)
                    ]
                    torch._foreach_add_(params_group, scaled_directions, alpha=-1.0)
            if vector_params:
                has_complex = any(torch.is_complex(p) for p in vector_params)
                try:
                    functional_adamw(
                        vector_params,
                        vector_grads,
                        adam_exp_avgs,
                        adam_exp_avg_sqs,
                        adam_max_exp_avg_sqs,
                        adam_state_steps,
                        foreach=None if fused_adamw_non_matrix else True,
                        capturable=False,
                        differentiable=False,
                        fused=True if fused_adamw_non_matrix else None,
                        grad_scale=None,
                        found_inf=None,
                        has_complex=has_complex,
                        amsgrad=False,
                        beta1=adam_beta1,
                        beta2=adam_beta2,
                        lr=lr,
                        weight_decay=weight_decay,
                        eps=adam_eps,
                        maximize=False,
                    )
                except RuntimeError:
                    # Fallback path for devices/builds that do not support fused functional AdamW.
                    if fused_adamw_non_matrix:
                        functional_adamw(
                            vector_params,
                            vector_grads,
                            adam_exp_avgs,
                            adam_exp_avg_sqs,
                            adam_max_exp_avg_sqs,
                            adam_state_steps,
                            foreach=True,
                            capturable=False,
                            differentiable=False,
                            fused=None,
                            grad_scale=None,
                            found_inf=None,
                            has_complex=has_complex,
                            amsgrad=False,
                            beta1=adam_beta1,
                            beta2=adam_beta2,
                            lr=lr,
                            weight_decay=weight_decay,
                            eps=adam_eps,
                            maximize=False,
                        )
                    else:
                        raise
        return loss
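

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; hyperparameters are just examples).
# The 2-D Linear weights take the orthogonalized AdaGO path, while the 1-D
# biases fall through to the functional AdamW path.
if __name__ == "__main__":
    model = torch.nn.Sequential(
        torch.nn.Linear(32, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 10),
    )
    optimizer = AdaGO(model.parameters(), lr=1e-2, weight_decay=0.01)
    for _ in range(3):
        loss = model(torch.randn(8, 32)).square().mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)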