Created February 9, 2026 19:58
Implementation of AdaGO: about the same speed as AdamW, but much better results for small local training runs.
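A minimal usage sketch (not part of the gist): it assumes the file below is saved as adago.py, and the model and data loader are placeholders. With the defaults, matrix-shaped weights (ndim >= 2) get the orthogonalized AdaGO update, while biases and other 1-D parameters fall back to AdamW.

import torch
from adago import AdaGO  # hypothetical module name for this gist file

model = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 10),
)
optimizer = AdaGO(model.parameters(), lr=1e-2, momentum=0.95, weight_decay=0.01)

for inputs, targets in loader:  # `loader` is a placeholder DataLoader
    optimizer.zero_grad(set_to_none=True)
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    optimizer.step()

The full implementation follows.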
from __future__ import annotations

from collections import defaultdict
from typing import Iterable

import torch
from torch import Tensor
from torch.optim import Optimizer
from torch.optim._functional import adamw as functional_adamw

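# Per-iteration (a, b, c) coefficients for the quintic Newton-Schulz polynomial used by the
# orthogonalization routines below: each step applies (a*I + b*S + c*S^2) @ W with S = W @ W.T.
# Requests for more steps than the table provides reuse the final tuple (see _ns_coeffs).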
_NS_COEFFS: list[tuple[float, float, float]] = [
    (4.0848, -6.8946, 2.9270),
    (3.9505, -6.3029, 2.6377),
    (3.7418, -5.5913, 2.3037),
    (2.8769, -3.1427, 1.2046),
    (2.8366, -3.0525, 1.2012),
]


def _normalize_vector(x: Tensor, eps: float) -> Tensor:
    if x.ndim == 0:
        return x / x.abs().clamp_min(eps)
    return x / torch.linalg.vector_norm(x).clamp_min(eps)


def _as_matrix(x: Tensor) -> tuple[Tensor, torch.Size | None]:
    if x.ndim < 2:
        return x, None
    return x.reshape(x.shape[0], -1), x.shape


def _orthogonalize_svd(x: Tensor, eps: float) -> Tensor:
    matrix, original_shape = _as_matrix(x)
    if original_shape is None:
        return _normalize_vector(matrix, eps)
    work = matrix.float()
    u, _, vh = torch.linalg.svd(work, full_matrices=False)
    ortho = (u @ vh).to(dtype=x.dtype)
    return ortho.reshape(original_shape)


def _ns_coeffs(num_steps: int) -> list[tuple[float, float, float]]:
    if num_steps < 1:
        raise ValueError("ns_steps must be >= 1")
    if num_steps <= len(_NS_COEFFS):
        return _NS_COEFFS[:num_steps]
    return _NS_COEFFS + [_NS_COEFFS[-1]] * (num_steps - len(_NS_COEFFS))


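# Newton-Schulz "heavyball" orthogonalization: normalize W by its Frobenius norm, then
# repeatedly apply the polynomial above so the singular values are pushed toward 1,
# approximating U @ V.T without an SVD. The matrix is transposed first when it has more
# rows than columns, so the Gram matrix S = W @ W.T takes the smaller dimension.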
def _orthogonalize_newton_schulz_heavyball(
    x: Tensor,
    eps: float,
    ns_steps: int,
    compute_dtype: torch.dtype,
) -> Tensor:
    matrix, original_shape = _as_matrix(x)
    if original_shape is None:
        return _normalize_vector(matrix, eps)
    work = matrix.to(dtype=compute_dtype)
    transposed = False
    if work.shape[0] > work.shape[1]:
        work = work.t()
        transposed = True
    norm = torch.linalg.vector_norm(work.float()).to(dtype=work.dtype).clamp_min(eps)
    work = work / norm
    for a, b, c in _ns_coeffs(ns_steps):
        s = work @ work.t()
        y = c * s
        y.diagonal().add_(b)
        y = y @ s
        y.diagonal().add_(a)
        work = y @ work
    if transposed:
        work = work.t()
    return work.to(dtype=x.dtype).reshape(original_shape)


def _orthogonalize_hybrid(
    x: Tensor,
    eps: float,
    svd_max_dim: int,
    ns_steps: int,
    compute_dtype: torch.dtype,
) -> Tensor:
    matrix, original_shape = _as_matrix(x)
    if original_shape is None:
        return _normalize_vector(matrix, eps)
    m, n = matrix.shape
    # Empirically best dispatch from this project:
    # - rank-1 matrices: SVD faster
    # - tiny matrices: SVD competitive / often faster
    # - otherwise: NS-heavyball much faster with strong quality
    if min(m, n) == 1 or max(m, n) <= svd_max_dim:
        return _orthogonalize_svd(x, eps)
    return _orthogonalize_newton_schulz_heavyball(
        x,
        eps=eps,
        ns_steps=ns_steps,
        compute_dtype=compute_dtype,
    )


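# Batched variants: parameters that share a shape are stacked into one 3-D tensor so a whole
# shape group can be orthogonalized with a handful of bmm calls; this is the path that
# AdaGO.step() takes for its per-shape groups below.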
def _stack_as_matrices(xs: list[Tensor], compute_dtype: torch.dtype) -> tuple[Tensor, int, int, torch.Size]:
    if not xs:
        raise ValueError("Expected at least one tensor.")
    first = xs[0]
    if first.ndim < 2:
        raise ValueError("Expected matrix-like tensors.")
    m = first.shape[0]
    n = first.numel() // m
    mats = torch.stack([x.reshape(m, n).to(dtype=compute_dtype) for x in xs], dim=0)
    return mats, m, n, first.shape


def _orthogonalize_svd_batched(xs: list[Tensor], eps: float) -> list[Tensor]:
    if len(xs) == 1:
        return [_orthogonalize_svd(xs[0], eps)]
    mats, _, _, original_shape = _stack_as_matrices(xs, compute_dtype=torch.float32)
    u, _, vh = torch.linalg.svd(mats, full_matrices=False)
    ortho = torch.bmm(u, vh).to(dtype=xs[0].dtype)
    return [ortho[i].reshape(original_shape) for i in range(ortho.shape[0])]


def _orthogonalize_newton_schulz_heavyball_batched(
    xs: list[Tensor],
    eps: float,
    ns_steps: int,
    compute_dtype: torch.dtype,
) -> list[Tensor]:
    if len(xs) == 1:
        return [
            _orthogonalize_newton_schulz_heavyball(
                xs[0],
                eps=eps,
                ns_steps=ns_steps,
                compute_dtype=compute_dtype,
            ),
        ]
    work, m, n, original_shape = _stack_as_matrices(xs, compute_dtype=compute_dtype)
    transposed = False
    if m > n:
        work = work.transpose(1, 2)
        transposed = True
    norms = torch.linalg.vector_norm(work.float(), dim=(1, 2), keepdim=True).to(dtype=work.dtype).clamp_min(eps)
    work = work / norms
    for a, b, c in _ns_coeffs(ns_steps):
        s = torch.bmm(work, work.transpose(1, 2))
        y = c * s
        y.diagonal(dim1=-2, dim2=-1).add_(b)
        y = torch.bmm(y, s)
        y.diagonal(dim1=-2, dim2=-1).add_(a)
        work = torch.bmm(y, work)
    if transposed:
        work = work.transpose(1, 2)
    work = work.to(dtype=xs[0].dtype)
    return [work[i].reshape(original_shape) for i in range(work.shape[0])]


def _orthogonalize_hybrid_batched(
    xs: list[Tensor],
    eps: float,
    svd_max_dim: int,
    ns_steps: int,
    compute_dtype: torch.dtype,
) -> list[Tensor]:
    if not xs:
        return []
    if xs[0].ndim < 2:
        return [_normalize_vector(x, eps) for x in xs]
    m = xs[0].shape[0]
    n = xs[0].numel() // m
    if min(m, n) == 1 or max(m, n) <= svd_max_dim:
        return _orthogonalize_svd_batched(xs, eps)
    return _orthogonalize_newton_schulz_heavyball_batched(
        xs,
        eps=eps,
        ns_steps=ns_steps,
        compute_dtype=compute_dtype,
    )


class AdaGO(Optimizer):
    """
    Hybrid AdaGO optimizer.

    Update:
        M_t = mu * M_{t-1} + (1-mu) * G_t
        v_t^2 = v_{t-1}^2 + min(||G_t||, gamma)^2
        alpha_t = max(eps_floor, lr * min(||G_t||, gamma) / v_t)
        theta_t = theta_{t-1} - alpha_t * Orth(M_t)

    For matrix-like params (ndim >= 2), Orth(.) uses a hybrid dispatcher:
        - SVD for tiny/rank-1 matrices
        - NS-heavyball elsewhere

    For scalar/vector params (ndim < 2), this optimizer optionally applies
    AdamW updates (paper-style hybrid setup with decoupled weight decay)
    instead of orthogonalized updates, via PyTorch's optimized functional path.
    """

    def __init__(
        self,
        params: Iterable[Tensor],
        lr: float = 1e-2,
        momentum: float = 0.95,
        gamma: float = 1.0,
        eps_floor: float = 1e-6,
        v0: float = 1e-3,
        weight_decay: float = 0.0,
        orth_eps: float = 1e-7,
        svd_max_dim: int = 8,
        ns_steps: int = 5,
        orth_compute_dtype: str = "fp32",
        norm_compute_dtype: str = "fp32",
        cache_shape_groups: bool = True,
        fused_adamw_non_matrix: bool = False,
        use_adam_for_non_matrix: bool = True,
        assume_dense_grads: bool = False,
        orth_update_interval: int = 1,
        adamw_numel_threshold: int = 0,
        adam_beta1: float = 0.9,
        adam_beta2: float = 0.999,
        adam_eps: float = 1e-8,
    ) -> None:
        if lr <= 0:
            raise ValueError("lr must be > 0")
        if not (0 <= momentum < 1):
            raise ValueError("momentum must be in [0, 1)")
        if gamma <= 0:
            raise ValueError("gamma must be > 0")
        if eps_floor <= 0:
            raise ValueError("eps_floor must be > 0")
        if v0 <= 0:
            raise ValueError("v0 must be > 0")
        if weight_decay < 0:
            raise ValueError("weight_decay must be >= 0")
        if orth_eps <= 0:
            raise ValueError("orth_eps must be > 0")
        if svd_max_dim < 1:
            raise ValueError("svd_max_dim must be >= 1")
        if ns_steps < 1:
            raise ValueError("ns_steps must be >= 1")
        if orth_compute_dtype not in ("fp32", "bf16"):
            raise ValueError("orth_compute_dtype must be one of: fp32, bf16")
        if norm_compute_dtype not in ("fp32", "bf16"):
            raise ValueError("norm_compute_dtype must be one of: fp32, bf16")
        if not (0 <= adam_beta1 < 1):
            raise ValueError("adam_beta1 must be in [0, 1)")
        if not (0 <= adam_beta2 < 1):
            raise ValueError("adam_beta2 must be in [0, 1)")
        if adam_eps <= 0:
            raise ValueError("adam_eps must be > 0")
        if orth_update_interval < 1:
            raise ValueError("orth_update_interval must be >= 1")
        if adamw_numel_threshold < 0:
            raise ValueError("adamw_numel_threshold must be >= 0")
        defaults = dict(
            lr=lr,
            momentum=momentum,
            gamma=gamma,
            eps_floor=eps_floor,
            v0=v0,
            weight_decay=weight_decay,
            orth_eps=orth_eps,
            svd_max_dim=svd_max_dim,
            ns_steps=ns_steps,
            orth_compute_dtype=orth_compute_dtype,
            norm_compute_dtype=norm_compute_dtype,
            cache_shape_groups=cache_shape_groups,
            fused_adamw_non_matrix=fused_adamw_non_matrix,
            use_adam_for_non_matrix=use_adam_for_non_matrix,
            assume_dense_grads=assume_dense_grads,
            orth_update_interval=orth_update_interval,
            adamw_numel_threshold=adamw_numel_threshold,
            adam_beta1=adam_beta1,
            adam_beta2=adam_beta2,
            adam_eps=adam_eps,
        )
        super().__init__(params, defaults)

    def _build_group_static_cache(
        self,
        group: dict,
        *,
        v0: float,
        use_adam_for_non_matrix: bool,
        orth_update_interval: int,
        adamw_numel_threshold: int,
    ) -> dict:
        # Partition params once per group: AdamW-managed params (non-matrix, or matrices below
        # adamw_numel_threshold) versus orthogonalized matrix params, plus per-shape index
        # groups for batched orthogonalization.
        matrix_params: list[Tensor] = []
        momentum_buffers: list[Tensor] = []
        v_sqs: list[Tensor] = []
        orth_directions: list[Tensor | None] = []
        vector_params: list[Tensor] = []
        adam_exp_avgs: list[Tensor] = []
        adam_exp_avg_sqs: list[Tensor] = []
        adam_max_exp_avg_sqs: list[Tensor] = []
        adam_state_steps: list[Tensor] = []
        shape_to_indices: dict[tuple[int, ...], list[int]] = defaultdict(list)
        for p in group["params"]:
            state = self.state[p]
            use_adamw_param = use_adam_for_non_matrix and (
                p.ndim < 2 or (adamw_numel_threshold > 0 and p.ndim >= 2 and p.numel() <= adamw_numel_threshold)
            )
            if use_adamw_param:
                if len(state) == 0 or "adam_exp_avg" not in state:
                    state["adam_exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["adam_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["adam_step"] = torch.tensor(0.0, dtype=torch.float32)
                vector_params.append(p)
                adam_exp_avgs.append(state["adam_exp_avg"])
                adam_exp_avg_sqs.append(state["adam_exp_avg_sq"])
                adam_state_steps.append(state["adam_step"])
                continue
            if len(state) == 0 or "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                state["v_sq"] = torch.full((), v0 * v0, dtype=torch.float32, device=p.device)
            if orth_update_interval > 1:
                if "orth_direction" not in state:
                    state["orth_direction"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                orth_direction = state["orth_direction"]
            else:
                orth_direction = None
            matrix_index = len(matrix_params)
            matrix_params.append(p)
            momentum_buffers.append(state["momentum_buffer"])
            v_sqs.append(state["v_sq"])
            orth_directions.append(orth_direction)
            shape_to_indices[tuple(p.shape)].append(matrix_index)
        shape_groups = list(shape_to_indices.values())
        shape_group_items = [
            (
                [matrix_params[i] for i in indices],
                [momentum_buffers[i] for i in indices],
                [v_sqs[i] for i in indices],
                [orth_directions[i] for i in indices],
            )
            for indices in shape_groups
        ]
        return {
            "group_param_count": len(group["params"]),
            "use_adam_for_non_matrix": use_adam_for_non_matrix,
            "orth_update_interval": orth_update_interval,
            "adamw_numel_threshold": adamw_numel_threshold,
            "matrix_params": matrix_params,
            "momentum_buffers": momentum_buffers,
            "v_sqs": v_sqs,
            "orth_directions": orth_directions,
            "vector_params": vector_params,
            "adam_exp_avgs": adam_exp_avgs,
            "adam_exp_avg_sqs": adam_exp_avg_sqs,
            "adam_max_exp_avg_sqs": adam_max_exp_avg_sqs,
            "adam_state_steps": adam_state_steps,
            "shape_groups": shape_groups,
            "shape_group_items": shape_group_items,
        }

    @staticmethod
    def _cache_valid(
        cache: dict | None,
        *,
        group: dict,
        use_adam_for_non_matrix: bool,
        orth_update_interval: int,
        adamw_numel_threshold: int,
    ) -> bool:
        if not isinstance(cache, dict):
            return False
        if cache.get("group_param_count") != len(group["params"]):
            return False
        if cache.get("use_adam_for_non_matrix") != use_adam_for_non_matrix:
            return False
        if cache.get("orth_update_interval") != orth_update_interval:
            return False
        if cache.get("adamw_numel_threshold") != adamw_numel_threshold:
            return False
        return True

    def _collect_dynamic_group_state(
        self,
        group: dict,
        *,
        v0: float,
        use_adam_for_non_matrix: bool,
        orth_update_interval: int,
        adamw_numel_threshold: int,
    ) -> tuple[
        list[Tensor],
        list[Tensor],
        list[Tensor],
        list[Tensor],
        list[Tensor | None],
        list[Tensor],
        list[Tensor],
        list[Tensor],
        list[Tensor],
        list[Tensor],
        list[Tensor],
        list[list[int]],
    ]:
        matrix_params: list[Tensor] = []
        matrix_grads: list[Tensor] = []
        momentum_buffers: list[Tensor] = []
        v_sqs: list[Tensor] = []
        orth_directions: list[Tensor | None] = []
        vector_params: list[Tensor] = []
        vector_grads: list[Tensor] = []
        adam_exp_avgs: list[Tensor] = []
        adam_exp_avg_sqs: list[Tensor] = []
        adam_max_exp_avg_sqs: list[Tensor] = []
        adam_state_steps: list[Tensor] = []
        shape_to_indices: dict[tuple[int, ...], list[int]] = defaultdict(list)
        for p in group["params"]:
            grad = p.grad
            if grad is None:
                continue
            if grad.is_sparse:
                raise RuntimeError("AdaGO does not support sparse gradients.")
            state = self.state[p]
            use_adamw_param = use_adam_for_non_matrix and (
                p.ndim < 2 or (adamw_numel_threshold > 0 and p.ndim >= 2 and p.numel() <= adamw_numel_threshold)
            )
            if use_adamw_param:
                if len(state) == 0 or "adam_exp_avg" not in state:
                    state["adam_exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["adam_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["adam_step"] = torch.tensor(0.0, dtype=torch.float32)
                vector_params.append(p)
                vector_grads.append(grad)
                adam_exp_avgs.append(state["adam_exp_avg"])
                adam_exp_avg_sqs.append(state["adam_exp_avg_sq"])
                adam_state_steps.append(state["adam_step"])
            else:
                if len(state) == 0 or "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["v_sq"] = torch.full((), v0 * v0, dtype=torch.float32, device=p.device)
                if orth_update_interval > 1:
                    if "orth_direction" not in state:
                        state["orth_direction"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    orth_direction = state["orth_direction"]
                else:
                    orth_direction = None
                matrix_idx = len(matrix_params)
                matrix_params.append(p)
                matrix_grads.append(grad)
                momentum_buffers.append(state["momentum_buffer"])
                v_sqs.append(state["v_sq"])
                orth_directions.append(orth_direction)
                shape_to_indices[tuple(p.shape)].append(matrix_idx)
        return (
            matrix_params,
            matrix_grads,
            momentum_buffers,
            v_sqs,
            orth_directions,
            vector_params,
            vector_grads,
            adam_exp_avgs,
            adam_exp_avg_sqs,
            adam_max_exp_avg_sqs,
            adam_state_steps,
            list(shape_to_indices.values()),
        )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr = float(group["lr"])
            momentum = float(group["momentum"])
            gamma = float(group["gamma"])
            eps_floor = float(group["eps_floor"])
            v0 = float(group["v0"])
            weight_decay = float(group["weight_decay"])
            orth_eps = float(group["orth_eps"])
            svd_max_dim = int(group["svd_max_dim"])
            ns_steps = int(group["ns_steps"])
            orth_compute_dtype = str(group["orth_compute_dtype"])
            norm_compute_dtype = str(group["norm_compute_dtype"])
            cache_shape_groups = bool(group["cache_shape_groups"])
            fused_adamw_non_matrix = bool(group["fused_adamw_non_matrix"])
            use_adam_for_non_matrix = bool(group["use_adam_for_non_matrix"])
            assume_dense_grads = bool(group.get("assume_dense_grads", False))
            orth_update_interval = int(group.get("orth_update_interval", 1))
            adamw_numel_threshold = int(group.get("adamw_numel_threshold", 0))
            adam_beta1 = float(group["adam_beta1"])
            adam_beta2 = float(group["adam_beta2"])
            adam_eps = float(group["adam_eps"])
            orth_dtype = torch.float32 if orth_compute_dtype == "fp32" else torch.bfloat16
            norm_dtype = torch.float32 if norm_compute_dtype == "fp32" else torch.bfloat16
            group_step = int(group.get("_adago_step", 0)) + 1
            group["_adago_step"] = group_step
            recompute_orth = ((group_step - 1) % orth_update_interval) == 0
            cache = group.get("_adago_static_cache")
            if not self._cache_valid(
                cache,
                group=group,
                use_adam_for_non_matrix=use_adam_for_non_matrix,
                orth_update_interval=orth_update_interval,
                adamw_numel_threshold=adamw_numel_threshold,
            ):
                cache = self._build_group_static_cache(
                    group,
                    v0=v0,
                    use_adam_for_non_matrix=use_adam_for_non_matrix,
                    orth_update_interval=orth_update_interval,
                    adamw_numel_threshold=adamw_numel_threshold,
                )
                group["_adago_static_cache"] = cache
            matrix_params: list[Tensor] = cache["matrix_params"]
            momentum_buffers: list[Tensor] = cache["momentum_buffers"]
            v_sqs: list[Tensor] = cache["v_sqs"]
            orth_directions: list[Tensor | None] = cache["orth_directions"]
            vector_params: list[Tensor] = cache["vector_params"]
            adam_exp_avgs: list[Tensor] = cache["adam_exp_avgs"]
            adam_exp_avg_sqs: list[Tensor] = cache["adam_exp_avg_sqs"]
            adam_max_exp_avg_sqs: list[Tensor] = cache["adam_max_exp_avg_sqs"]
            adam_state_steps: list[Tensor] = cache["adam_state_steps"]
            cached_shape_groups: list[list[int]] = cache["shape_groups"]
            cached_shape_group_items = cache["shape_group_items"]
            matrix_grads: list[Tensor]
            vector_grads: list[Tensor]
            use_dynamic_state_collection = False
            if assume_dense_grads:
                matrix_grads = [p.grad for p in matrix_params]
                vector_grads = [p.grad for p in vector_params]
                if any(g is None for g in matrix_grads) or any(g is None for g in vector_grads):
                    use_dynamic_state_collection = True
                else:
                    if any(g.is_sparse for g in matrix_grads) or any(g.is_sparse for g in vector_grads):
                        raise RuntimeError("AdaGO does not support sparse gradients.")
            else:
                all_matrix_have_grads = True
                matrix_grads = []
                for p in matrix_params:
                    grad = p.grad
                    if grad is None:
                        all_matrix_have_grads = False
                        break
                    if grad.is_sparse:
                        raise RuntimeError("AdaGO does not support sparse gradients.")
                    matrix_grads.append(grad)
                all_vector_have_grads = True
                vector_grads = []
                for p in vector_params:
                    grad = p.grad
                    if grad is None:
                        all_vector_have_grads = False
                        break
                    if grad.is_sparse:
                        raise RuntimeError("AdaGO does not support sparse gradients.")
                    vector_grads.append(grad)
                if not all_matrix_have_grads or not all_vector_have_grads:
                    use_dynamic_state_collection = True
            if use_dynamic_state_collection:
                (
                    matrix_params,
                    matrix_grads,
                    momentum_buffers,
                    v_sqs,
                    orth_directions,
                    vector_params,
                    vector_grads,
                    adam_exp_avgs,
                    adam_exp_avg_sqs,
                    adam_max_exp_avg_sqs,
                    adam_state_steps,
                    shape_groups_runtime,
                ) = self._collect_dynamic_group_state(
                    group,
                    v0=v0,
                    use_adam_for_non_matrix=use_adam_for_non_matrix,
                    orth_update_interval=orth_update_interval,
                    adamw_numel_threshold=adamw_numel_threshold,
                )
                use_cached_shape_group_items = False
            else:
                shape_groups_runtime = cached_shape_groups
                use_cached_shape_group_items = cache_shape_groups and assume_dense_grads
            if not matrix_params and not vector_params:
                continue
            # Decoupled weight decay is applied directly to matrix params; AdamW-managed
            # params receive theirs inside functional_adamw below.
            if weight_decay != 0:
                if matrix_params:
                    torch._foreach_mul_(matrix_params, 1.0 - lr * weight_decay)
            # Momentum update for matrix params: M_t = mu * M_{t-1} + (1-mu) * G_t.
            if matrix_params:
                torch._foreach_mul_(momentum_buffers, momentum)
                torch._foreach_add_(momentum_buffers, matrix_grads, alpha=1.0 - momentum)
            if use_cached_shape_group_items:
                shape_group_items = cached_shape_group_items
                shape_groups = None
            elif cache_shape_groups:
                shape_groups = shape_groups_runtime
                shape_group_items = None
            else:
                shape_to_indices: dict[tuple[int, ...], list[int]] = defaultdict(list)
                for i, p in enumerate(matrix_params):
                    shape_to_indices[tuple(p.shape)].append(i)
                shape_groups = list(shape_to_indices.values())
                shape_group_items = None
            if shape_group_items is not None:
                iterable_groups = (
                    (
                        params_group,
                        [p.grad for p in params_group],
                        momentum_group,
                        v_group,
                        orth_direction_group,
                    )
                    for (params_group, momentum_group, v_group, orth_direction_group) in shape_group_items
                )
            else:
                iterable_groups = (
                    (
                        [matrix_params[i] for i in indices],
                        [matrix_grads[i] for i in indices],
                        [momentum_buffers[i] for i in indices],
                        [v_sqs[i] for i in indices],
                        [orth_directions[i] for i in indices],
                    )
                    for indices in shape_groups
                )
            # For each same-shape group: clip gradient norms to gamma, advance the scalar
            # accumulator v_t^2, form the step size alpha_t, and apply the (optionally
            # cached) orthogonalized momentum direction.
            for params_group, grads_group, momentum_group, v_group, orth_direction_group in iterable_groups:
                if not grads_group or grads_group[0] is None:
                    continue
                first_grad = grads_group[0]
                if first_grad.ndim < 2:
                    g_norms = torch.stack(
                        [torch.linalg.vector_norm(g.to(dtype=norm_dtype)) for g in grads_group],
                        dim=0,
                    )
                else:
                    m = first_grad.shape[0]
                    n = first_grad.numel() // m
                    g_work = torch.stack([g.reshape(m, n).to(dtype=norm_dtype) for g in grads_group], dim=0)
                    g_norms = torch.linalg.vector_norm(g_work, dim=(1, 2))
                g_norms = g_norms.clamp_max(gamma)
                v_vals = torch.stack(v_group, dim=0)
                v_vals.add_(g_norms.square())
                scales = torch.clamp_min(lr * g_norms / v_vals.sqrt(), eps_floor)
                for v_sq, new_v in zip(v_group, v_vals):
                    v_sq.copy_(new_v)
                if recompute_orth:
                    directions = _orthogonalize_hybrid_batched(
                        momentum_group,
                        eps=orth_eps,
                        svd_max_dim=svd_max_dim,
                        ns_steps=ns_steps,
                        compute_dtype=orth_dtype,
                    )
                    if orth_update_interval > 1:
                        for orth_cache, direction in zip(orth_direction_group, directions):
                            if orth_cache is not None:
                                orth_cache.copy_(direction)
                    for direction, scale in zip(directions, scales):
                        direction.mul_(scale.to(dtype=direction.dtype))
                    torch._foreach_add_(params_group, directions, alpha=-1.0)
                else:
                    if any(direction is None for direction in orth_direction_group):
                        directions = _orthogonalize_hybrid_batched(
                            momentum_group,
                            eps=orth_eps,
                            svd_max_dim=svd_max_dim,
                            ns_steps=ns_steps,
                            compute_dtype=orth_dtype,
                        )
                        if orth_update_interval > 1:
                            for orth_cache, direction in zip(orth_direction_group, directions):
                                if orth_cache is not None:
                                    orth_cache.copy_(direction)
                    else:
                        directions = orth_direction_group
                    scaled_directions = [direction * scale.to(dtype=direction.dtype) for direction, scale in zip(directions, scales)]
                    torch._foreach_add_(params_group, scaled_directions, alpha=-1.0)
            # Non-matrix (and below-threshold) params take the AdamW path, preferring the
            # fused kernel when requested.
            if vector_params:
                has_complex = any(torch.is_complex(p) for p in vector_params)
                try:
                    functional_adamw(
                        vector_params,
                        vector_grads,
                        adam_exp_avgs,
                        adam_exp_avg_sqs,
                        adam_max_exp_avg_sqs,
                        adam_state_steps,
                        foreach=None if fused_adamw_non_matrix else True,
                        capturable=False,
                        differentiable=False,
                        fused=True if fused_adamw_non_matrix else None,
                        grad_scale=None,
                        found_inf=None,
                        has_complex=has_complex,
                        amsgrad=False,
                        beta1=adam_beta1,
                        beta2=adam_beta2,
                        lr=lr,
                        weight_decay=weight_decay,
                        eps=adam_eps,
                        maximize=False,
                    )
                except RuntimeError as exc:
                    # Fallback path for devices/builds that do not support fused functional AdamW.
                    if fused_adamw_non_matrix:
                        functional_adamw(
                            vector_params,
                            vector_grads,
                            adam_exp_avgs,
                            adam_exp_avg_sqs,
                            adam_max_exp_avg_sqs,
                            adam_state_steps,
                            foreach=True,
                            capturable=False,
                            differentiable=False,
                            fused=None,
                            grad_scale=None,
                            found_inf=None,
                            has_complex=has_complex,
                            amsgrad=False,
                            beta1=adam_beta1,
                            beta2=adam_beta2,
                            lr=lr,
                            weight_decay=weight_decay,
                            eps=adam_eps,
                            maximize=False,
                        )
                    else:
                        raise exc
        return loss