@shakes76
Created February 8, 2026 01:37
Adaptive Gradient Clipping in JAX
import jax
import jax.numpy as jnp


@jax.jit
def l2_norm(tree):
    """Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
    leaves, _ = jax.tree.flatten(tree)
    return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves))


@jax.jit
def adaptive_grad_clip(params, grads, clip_factor=0.01, eps=1e-3):
    """Adaptive gradient clipping (AGC), as proposed in "High-Performance
    Large-Scale Image Recognition Without Normalization":
    http://arxiv.org/abs/2102.06171

    The paper identifies AGC as a crucial component for training deep
    neural networks without batch normalization. Gradients are rescaled
    whenever their norm exceeds clip_factor times the parameter norm.

    Note: this version uses a single global norm over the whole pytree,
    whereas the paper applies the same rule unit-wise within each layer.
    """
    p_norm = l2_norm(params)
    # Largest gradient norm allowed relative to the parameter norm;
    # eps keeps the threshold sensible for near-zero parameters.
    max_norm = jnp.maximum(p_norm, eps) * clip_factor
    norm = l2_norm(grads)
    # Leave small gradients untouched; rescale large ones down to max_norm.
    normalize = lambda g: jnp.where(
        norm < max_norm, g, g * (max_norm / jnp.maximum(norm, 1e-6))
    )
    return jax.tree.map(normalize, grads)
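
For reference, here is a minimal sketch of dropping adaptive_grad_clip into a plain gradient-descent training step. The toy linear model, random data, and learning rate below are hypothetical stand-ins, not part of the original gist.

def loss_fn(params, x, y):
    # Toy linear model with a mean-squared-error loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y, lr=1e-2):
    grads = jax.grad(loss_fn)(params, x, y)
    # Clip the raw gradients before the parameter update.
    grads = adaptive_grad_clip(params, grads)
    return jax.tree.map(lambda p, g: p - lr * g, params, grads)


key = jax.random.key(0)
k1, k2 = jax.random.split(key)
params = {"w": jax.random.normal(k1, (8, 1)), "b": jnp.zeros((1,))}
x = jax.random.normal(k2, (32, 8))
y = jnp.ones((32, 1))
params = train_step(params, x, y)

Clipping is applied to the gradients before the optimizer update, which matches where the paper inserts AGC in the training loop.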