@shakes76
Created February 8, 2026 01:39
SQNL (Square non-linearity) activation function in JAX
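For reference, this is the piecewise function that `sqnl` implements, written out together with its derivative. The derivative is derived by hand from the code's branches, not quoted from the linked paper. It is continuous and falls to 0 at x = ±2, so the quadratic pieces join the flat saturation regions smoothly.

$$
\operatorname{SQNL}(x) =
\begin{cases}
1, & x > 2,\\
x - \dfrac{x^2}{4}, & 0 \le x \le 2,\\
x + \dfrac{x^2}{4}, & -2 \le x < 0,\\
-1, & x < -2,
\end{cases}
\qquad
\frac{d}{dx}\operatorname{SQNL}(x) =
\begin{cases}
1 - \dfrac{|x|}{2}, & |x| \le 2,\\
0, & |x| > 2.
\end{cases}
$$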
```python
import jax
import jax.numpy as jnp

# Square non-linearity (SQNL)
# https://ieeexplore.ieee.org/document/8489043
@jax.jit
def sqnl(x):
    # Piecewise quadratic on [-2, 2]. Inputs outside that range saturate
    # at -1 or 1, so the output always lies in [-1, 1].
    # Range [0, 2]: x - x^2/4 (rises from 0 to 1)
    pos = jnp.where(jnp.logical_and(0.0 <= x, x <= 2.0), x - jnp.square(x) / 4.0, x)
    # Range [-2, 0): x + x^2/4 (rises from -1 towards 0)
    neg = jnp.where(jnp.logical_and(-2.0 <= x, x < 0.0), x + jnp.square(x) / 4.0, pos)
    # Inputs outside [-2, 2] passed through unchanged above; clipping
    # maps them to -1 (x < -2) or 1 (x > 2).
    return jnp.clip(neg, -1.0, 1.0)
```
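A quick sanity check, offered as a sketch rather than part of the original gist. It re-declares `sqnl` so the block runs standalone. It then compares `sqnl` against a branch-free rewrite and checks `jax.grad` against the analytic derivative. `sqnl_closed_form` and `sqnl_grad_reference` are hypothetical helpers introduced here for illustration. The closed form assumes the identity SQNL(x) = y - y|y|/4 with y = clip(x, -2, 2).

```python
import jax
import jax.numpy as jnp


@jax.jit
def sqnl(x):
    # Same piecewise definition as the gist, repeated so this block is self-contained.
    pos = jnp.where(jnp.logical_and(0.0 <= x, x <= 2.0), x - jnp.square(x) / 4.0, x)
    neg = jnp.where(jnp.logical_and(-2.0 <= x, x < 0.0), x + jnp.square(x) / 4.0, pos)
    return jnp.clip(neg, -1.0, 1.0)


@jax.jit
def sqnl_closed_form(x):
    # Hypothetical branch-free variant (not from the gist): clamp the input to
    # [-2, 2] first, then apply y - y*|y|/4. This equals y - y^2/4 for y >= 0
    # and y + y^2/4 for y < 0, and gives exactly +/-1 at y = +/-2.
    y = jnp.clip(x, -2.0, 2.0)
    return y - y * jnp.abs(y) / 4.0


def sqnl_grad_reference(x):
    # Hypothetical helper: analytic derivative, 1 - |x|/2 inside [-2, 2], 0 outside.
    return jnp.where(jnp.abs(x) <= 2.0, 1.0 - jnp.abs(x) / 2.0, 0.0)


xs = jnp.linspace(-4.0, 4.0, 81)

# Forward passes of the two formulations should agree on the whole grid.
print(jnp.allclose(sqnl(xs), sqnl_closed_form(xs)))

# Per-element gradients from autodiff, compared with the analytic derivative.
grads = jax.vmap(jax.grad(sqnl))(xs)
print(jnp.allclose(grads, sqnl_grad_reference(xs), atol=1e-5))
```

Clamping the input first means only one quadratic expression is evaluated, rather than computing both branches and selecting between them with `jnp.where`.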