@ejmejm · Created February 13, 2026
Equinox Optimizer
from typing import Optional, Tuple

import equinox as eqx
from jaxtyping import PyTree
import optax


class EqxOptimizer(eqx.Module):
    """Wraps an optax optimizer and its state as an immutable Equinox module."""

    name: Optional[str] = eqx.field(static=True)
    optimizer: optax.GradientTransformation = eqx.field(static=True)
    filter_spec: Optional[PyTree] = eqx.field(default=None, static=True)
    state: PyTree

    def __init__(
        self,
        optimizer: optax.GradientTransformation,
        model: eqx.Module,
        filter_spec: Optional[PyTree] = None,
        name: Optional[str] = None,
    ):
        self.optimizer = optimizer
        self.filter_spec = filter_spec
        if filter_spec is not None:
            trainable_params = eqx.filter(model, filter_spec)
        else:
            # Keep only array leaves so optax never sees static fields,
            # activation functions, or other non-array parts of the module
            # (optax.init would fail trying to build state for them).
            trainable_params = eqx.filter(model, eqx.is_array)
        self.state = self.optimizer.init(trainable_params)
        self.name = name

    def with_update(self, grads, model) -> Tuple[PyTree, 'EqxOptimizer']:
        """Compute parameter updates and return them with a new optimizer.

        The module is immutable, so rather than mutating `self.state` this
        returns a copy of the optimizer carrying the new state.
        """
        if self.filter_spec is not None:
            # Restrict gradients and params to the trainable subset.
            if isinstance(grads, tuple):
                grads = tuple(eqx.filter(g, self.filter_spec) for g in grads)
            else:
                grads = eqx.filter(grads, self.filter_spec)
            model = eqx.filter(model, self.filter_spec)
        else:
            # Match the array-only structure used for init above.
            model = eqx.filter(model, eqx.is_array)
        updates, new_state = self.optimizer.update(grads, self.state, model)
        return updates, eqx.tree_at(lambda x: x.state, self, new_state)
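A minimal training-step sketch of how the class is meant to be used. The MLP, loss function, and data here are illustrative assumptions, not part of the gist; the key pattern is `with_update` returning both the optax updates and a fresh optimizer, with `eqx.apply_updates` applying the updates to the model:

import jax
import jax.numpy as jnp
import equinox as eqx
import optax

key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(in_size=2, out_size=1, width_size=32, depth=2, key=key)  # hypothetical model
opt = EqxOptimizer(optax.adam(1e-3), model, name='adam')

@eqx.filter_jit
def train_step(model, opt, x, y):
    def loss_fn(model):
        preds = jax.vmap(model)(x)
        return jnp.mean((preds - y) ** 2)
    grads = eqx.filter_grad(loss_fn)(model)
    updates, opt = opt.with_update(grads, model)
    # Apply the optax updates to the model's trainable leaves.
    model = eqx.apply_updates(model, updates)
    return model, opt

x = jax.random.normal(key, (64, 2))
y = jnp.sum(x, axis=1, keepdims=True)
model, opt = train_step(model, opt, x, y)

The `filter_spec` argument appears to follow the usual Equinox freezing pattern: a pytree of booleans, matching the model's structure, that marks which leaves are trainable. A hypothetical example that trains only the final layer's weight of the MLP above:

# Mark every leaf as frozen, then flip the final layer's weight to trainable.
filter_spec = jax.tree_util.tree_map(lambda _: False, model)
filter_spec = eqx.tree_at(lambda m: m.layers[-1].weight, filter_spec, replace=True)
opt = EqxOptimizer(optax.adam(1e-3), model, filter_spec=filter_spec, name='adam-last-layer')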