Created
February 13, 2026 03:28
-
-
Save ejmejm/39cab180a27012039f51e3a19cd94835 to your computer and use it in GitHub Desktop.
Equinox Optimizer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Optional, Tuple | |
| import equinox as eqx | |
| from jaxtyping import PyTree | |
| import optax | |
def _filter_trainable(tree: PyTree, filter_spec: Optional[PyTree]) -> PyTree:
    """Return ``tree`` restricted to the leaves selected by ``filter_spec``.

    ``None`` selects every array leaf (``eqx.is_array``). This mirrors the usual
    Equinox pattern ``optimizer.init(eqx.filter(model, eqx.is_array))`` and keeps
    non-array leaves (e.g. callables stored on a module) away from optax.
    """
    return eqx.filter(tree, eqx.is_array if filter_spec is None else filter_spec)


class EqxOptimizer(eqx.Module):
    """An optax optimizer bundled with its state as an immutable Equinox module.

    The optax state is held in the (non-static) ``state`` field; the optimizer,
    filter spec and name are declared with ``eqx.field(static=True)``. Only the
    leaves of the model selected by ``filter_spec`` are given to the optimizer.
    """

    name: Optional[str] = eqx.field(static=True)
    optimizer: optax.GradientTransformation = eqx.field(static=True)
    # Filter spec for ``eqx.filter``; ``None`` means "all array leaves".
    filter_spec: Optional[PyTree] = eqx.field(default=None, static=True)
    # optax state produced by ``optimizer.init`` on the filtered model.
    state: PyTree

    def __init__(
        self,
        optimizer: optax.GradientTransformation,
        model: eqx.Module,
        filter_spec: Optional[PyTree] = None,
        name: Optional[str] = None,
    ):
        """Initialise the optimizer state for the trainable part of ``model``.

        Args:
            optimizer: The optax transformation to wrap.
            model: The model whose parameters will be optimized.
            filter_spec: Filter spec passed to ``eqx.filter`` to select the
                trainable leaves of ``model``. ``None`` selects every array
                leaf (``eqx.is_array``) rather than the raw model, so
                non-array leaves never reach ``optimizer.init``.
            name: Optional label for this optimizer.
        """
        self.optimizer = optimizer
        self.filter_spec = filter_spec
        self.name = name
        self.state = optimizer.init(_filter_trainable(model, filter_spec))

    def with_update(self, grads, model) -> Tuple[PyTree, 'EqxOptimizer']:
        """Compute parameter updates and return them with an advanced optimizer.

        Args:
            grads: Gradients for ``model``. If a tuple is given, each element is
                filtered separately against the same ``filter_spec``
                (presumably for callers passing several gradient pytrees —
                confirm against call sites).
            model: The current model; its filtered leaves are passed to
                ``optimizer.update`` as ``params``.

        Returns:
            ``(updates, new_optimizer)``: the updates from ``optimizer.update``
            and a copy of this optimizer holding the new state (built with
            ``eqx.tree_at``; ``self`` is left unchanged).
        """
        spec = self.filter_spec
        if isinstance(grads, tuple):
            grads = tuple(_filter_trainable(g, spec) for g in grads)
        else:
            grads = _filter_trainable(grads, spec)
        # Filter params the same way as in __init__ so they line up with state.
        params = _filter_trainable(model, spec)
        updates, new_state = self.optimizer.update(grads, self.state, params)
        return updates, eqx.tree_at(lambda opt: opt.state, self, new_state)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment