Skip to content

Instantly share code, notes, and snippets.

@howsiyu
Created January 2, 2026 01:31
Show Gist options
  • Select an option

  • Save howsiyu/493d45ba513a61dc609ca67c41f7e29e to your computer and use it in GitHub Desktop.

Select an option

Save howsiyu/493d45ba513a61dc609ca67c41f7e29e to your computer and use it in GitHub Desktop.
Implement scipy.sparse.linalg.onenormest in idiomatic JAX
from functools import partial
from typing import Any, NamedTuple, Protocol
import jax
import jax.numpy as jnp
from jax import lax
from jax import Array
class LinearOperator(Protocol):
    """
    Structural interface for the operator A whose 1-norm is estimated.

    Any object exposing these attributes and methods can be used; A does not
    have to be stored as an explicit matrix.
    """

    # Element dtype of the operator; a complex dtype selects the complex code path.
    dtype: Any
    # (number of rows, number of columns) of A.
    shape: tuple[int, int]

    def matmat(self, X: Array) -> Array:
        """
        Performs the operation A @ X.
        """
        ...

    def rmatmat(self, X: Array) -> Array:
        """
        Performs the operation A.H @ X where A.H denotes the conjugate transpose of A.
        """
        ...

    def cols(self, indices: Array) -> Array:
        """
        Computes A[:, indices]. This can often be performed faster than using matmat.
        """
        ...
class OnenormestLoopState(NamedTuple):
    """
    Loop carry of the lax.while_loop that drives the block 1-norm iteration.
    """

    # True once a stopping criterion has fired; the loop runs while this is False.
    break_flag: Array
    # Number of loop iterations
    k: Array
    # Number of parallel columns of S we have detected so far.
    # To avoid randomness, we replace columns of S that are parallel to S_old or another
    # column of S by columns whose entries alternate sign every i rows for i = 1, 2, ...
    # This should be good enough as replacing parallel columns of S has little effect
    # on most matrices anyway (and is not done for complex matrices).
    s: Array
    # Current sign matrix, one column per parallel probe.
    S: Array
    # Boolean vector recording indices of used unit vectors e_j
    ind_hist: Array
    # Indices of the t largest columns we have encountered so far.
    best_t_indices: Array
    # 1-norms of columns in best_t_indices
    best_t_estimates: Array
@partial(jax.jit, static_argnums=[2, 4, 5])
def onenormest(
    key,
    A: "LinearOperator",
    t: int,
    itmax=5,
    compute_v: bool = False,
    compute_w: bool = False,
):
    """
    Estimate the 1-norm of the linear operator A.

    When 4 * t >= n (n being the number of columns of A) every column is
    formed through A.cols and the largest absolute column sum is returned.
    Otherwise a random sign (real dtype) or unit-modulus (complex dtype)
    matrix with t columns is drawn from key, mapped through A.rmatmat and
    sign_round_up, and handed to onenormest_core as the starting matrix.

    Args:
        key: JAX PRNG key; only consumed by the iterative branch.
        A: Operator providing shape, dtype, rmatmat and cols.
            NOTE(review): A is a traced (non-static) jit argument, so it
            presumably has to be a JAX pytree -- confirm against callers.
        t: Number of columns iterated in parallel (static under jit).
        itmax: Iteration limit forwarded to onenormest_core.
        compute_v: Also return the unit vector e_j selecting the chosen column.
        compute_w: Also return the chosen column A.cols(j).

    Returns:
        The estimate alone when neither certificate is requested, otherwise a
        tuple (est, [v], [w]) containing the requested certificates in that order.

    Raises:
        ValueError: if t < 1.
    """
    # Fail early with the same message onenormest_core uses, before any
    # random sampling or operator call is traced.
    if t < 1:
        raise ValueError("at least one column is required")

    m, n = A.shape
    if 4 * t >= n:
        # If the operator size is small compared to t,
        # then it is easier to compute the exact norm.
        A_explicit = A.cols(jnp.arange(n))
        A_abs = lax.abs(A_explicit).sum(axis=0)
        j = jnp.argmax(A_abs)
        est = A_abs[j]
    else:
        if jnp.issubdtype(A.dtype, jnp.complexfloating):
            real_dtype = jnp.zeros((), dtype=A.dtype).real.dtype
            u = jax.random.uniform(key, shape=(m, t), dtype=real_dtype)
            S = jnp.exp(u * (jnp.pi * 2j))
        else:
            S = jax.random.bernoulli(key, shape=(m, t)).astype(A.dtype) * -2 + 1
        # My experiment shows this performs similarly to the starting matrix with
        # ones and rand{-1, 1}s as suggested by the paper on random matrices but
        # vastly outperforms the latter for the pathological triangular matrix
        # example given in the paper.
        starting_matrix = sign_round_up(A.rmatmat(S))
        final_state = onenormest_core(A, starting_matrix, jnp.array(itmax))
        j = final_state.best_t_indices[0]
        est = final_state.best_t_estimates[0]

    # Report the norm estimate along with some certificates of the estimate.
    if compute_v or compute_w:
        result = (est,)
        if compute_v:
            v = jnp.zeros(n, A.dtype).at[j].set(1)
            result += (v,)
        if compute_w:
            result += (A.cols(j),)
        return result
    else:
        return est
def onenormest_core(
    A: LinearOperator, starting_matrix: Array, itmax: Array
) -> OnenormestLoopState:
    """
    Run the block 1-norm iteration from a caller-supplied starting matrix.

    This is a deterministic variant of Algorithm 2.4 in
    Nicholas J. Higham and Francoise Tisseur (2000),
    "A Block Algorithm for Matrix 1-Norm Estimation,
    with an Application to 1-Norm Pseudospectra."
    The returned loop state carries, in best_t_indices / best_t_estimates,
    t column indices together with their 1-norms that estimate the largest
    t columns of A.

    Raises ValueError when starting_matrix does not have n rows, when it has
    no columns, or when 4 * t >= n.
    """
    m, n = A.shape
    start_rows, t = starting_matrix.shape
    if start_rows != n:
        raise ValueError("A and starting_matrix must have compatible shapes!")
    if t < 1:
        raise ValueError("at least one column is required")
    if 4 * t >= n:
        raise ValueError("4 * t should be smaller than the order of A")

    # Parallel-column resampling is only performed for real operators.
    real_valued = not jnp.issubdtype(A.dtype, jnp.complexfloating)

    # If (itmax + 1) * t > n we could have s >= n which makes us loop
    # indefinitely when resampling S's columns. In practice this algorithm
    # only makes sense if n >> itmax * t so this shouldn't happen.
    itmax = jnp.clip(itmax, min=2, max=n // t - 1)

    def keep_iterating(state: OnenormestLoopState):
        return jnp.logical_not(state.break_flag)

    def iterate(state: OnenormestLoopState) -> OnenormestLoopState:
        k = state.k
        S_cur, resampled = state.S, state.s
        if real_valued:
            S_cur, resampled = resample_S1(S_cur, resampled)

        # (3) in the paper: score every unit vector e_j.
        h = jnp.max(jnp.abs(A.rmatmat(S_cur)), axis=1)
        order = jnp.argsort(h, descending=True)
        # Positions, within the sorted order, of the first t indices not yet visited.
        (fresh_pos,) = jnp.nonzero(jnp.logical_not(state.ind_hist[order]), size=t)

        no_better_score = h[order[0]] == h[state.best_t_indices[0]]  # break condition in (4)
        top_t_all_seen = fresh_pos[0] >= t  # break condition in (5)
        stop = (k >= 2) & (no_better_score | top_t_all_seen)

        ind = order[fresh_pos]
        seen = state.ind_hist.at[ind].set(True)

        Y = A.cols(ind)
        col_norms = lax.abs(Y).sum(axis=0)
        # break condition in (1): no visited column beats the previous best.
        stop = stop | (col_norms.max() <= state.best_t_estimates[0])

        # Merge the new columns with the running top-t and keep the t largest.
        pooled_estimates = jnp.concatenate([state.best_t_estimates, col_norms])
        top_estimates, top_pos = lax.top_k(pooled_estimates, t)
        top_indices = jnp.concatenate([state.best_t_indices, ind])[top_pos]

        S_next = sign_round_up(Y)
        if real_valued:
            repeats = jnp.any(jnp.abs(S_cur.T @ S_next) == m, axis=0)
            stop = stop | jnp.all(repeats)  # break condition in (2)
            S_next, resampled = resample_S2(S_next, resampled, S_cur, repeats)

        return OnenormestLoopState(
            break_flag=stop | (k >= itmax),
            k=k + 1,
            s=resampled,
            S=S_next,
            ind_hist=seen,
            best_t_indices=top_indices,
            best_t_estimates=top_estimates,
        )

    first_product = A.matmat(starting_matrix)
    initial = OnenormestLoopState(
        break_flag=jnp.array(False),
        k=jnp.array(1),
        s=jnp.array(0),
        S=sign_round_up(first_product),
        ind_hist=jnp.zeros(n, dtype=jnp.bool),
        best_t_indices=jnp.zeros(t, dtype=int),
        # Start below any attainable 1-norm so the first iteration always improves.
        best_t_estimates=jnp.full(t, -jnp.abs(jnp.array(1, A.dtype))),
    )
    return lax.while_loop(keep_iterating, iterate, initial)
def sign_round_up(Y: Array):
    """
    Map every entry of Y to a sign of unit modulus, sending zero to +1.

    Real input yields -1 for strictly negative entries and +1 otherwise.
    Complex input yields Y / |Y|, with zero entries replaced by 1.
    The result has the same dtype as Y.
    """
    if jnp.isrealobj(Y):
        # Compare against zero instead of reading the sign bit, so that -0.0
        # is rounded up to +1 like +0.0.
        return jnp.where(Y < 0, -1, 1).astype(Y.dtype)
    # Substitute 1 for exact zeros so the normalisation never divides by zero.
    Y = lax.select(Y != 0, Y, jnp.ones_like(Y))
    return Y / jnp.abs(Y)
class ResampleSLoopState(NamedTuple):
    """
    Loop carry of the while-loops that replace parallel columns of S.
    """

    # Sign matrix whose flagged columns are being replaced.
    S: Array
    # Running counter passed to and returned from resample_S_core; it selects
    # the sign pattern of the next replacement column.
    s: Array
    # One boolean per column of S; True marks a column that still has to be replaced.
    parallel_flag: Array
def resample_S1(S, s):
    """
    Ensure that no column of S is parallel to another column of S.

    Columns that are parallel to an earlier column are replaced through
    resample_S_core until none is flagged. Returns the repaired S and the
    updated counter s.
    """
    m = S.shape[0]

    def flag_later_duplicates(M):
        # Columns have +-1 entries, so |<a, b>| == m exactly when a == +-b.
        # The strict upper triangle only marks a column against earlier ones.
        is_parallel = jnp.abs(M.T @ M) == m
        return jnp.any(jnp.triu(is_parallel, k=1), axis=0)

    def body_fun(state: ResampleSLoopState):
        S_next, s_next = resample_S_core(state)
        return ResampleSLoopState(S_next, s_next, flag_later_duplicates(S_next))

    start = ResampleSLoopState(S, s, flag_later_duplicates(S))
    result = jax.lax.while_loop(resample_S_cond_fun, body_fun, start)
    return result.S, result.s
def resample_S2(S, s, S_old, parallel_to_S_old):
    """
    Ensure that no column of S is parallel to any column of S_old.

    parallel_to_S_old holds the initial per-column flags. Flagged columns are
    replaced through resample_S_core and re-checked against S_old until none
    is flagged. Returns the repaired S and the updated counter s.
    """
    m = S.shape[0]

    def flag_against_old(M):
        # Columns have +-1 entries, so |<a, b>| == m exactly when a == +-b.
        return jnp.any(jnp.abs(S_old.T @ M) == m, axis=0)

    def body_fun(state: ResampleSLoopState):
        S_next, s_next = resample_S_core(state)
        return ResampleSLoopState(S_next, s_next, flag_against_old(S_next))

    start = ResampleSLoopState(S, s, parallel_to_S_old)
    result = jax.lax.while_loop(resample_S_cond_fun, body_fun, start)
    return result.S, result.s
def resample_S_cond_fun(state: "ResampleSLoopState"):
    """
    Loop condition: True while at least one column is still flagged as parallel.
    """
    return jnp.any(state.parallel_flag)
def resample_S_core(state: "ResampleSLoopState"):
    """
    Replace every flagged column of S by a deterministic sign pattern.

    The i-th flagged column (counting from s + 1) becomes the column whose
    entries are +1 for the first i rows, -1 for the next i rows, and so on.
    Unflagged columns are returned unchanged.

    Returns the updated S (same dtype as the input) and the advanced counter,
    i.e. s plus the number of flagged columns, as an integer array.
    """
    S, s, parallel_flag = state
    rows = jnp.arange(S.shape[0])
    # Do all index arithmetic in integers: in S.dtype (e.g. float16) large row
    # indices would be rounded and the sign pattern would come out wrong.
    start = jnp.asarray(s).astype(rows.dtype)
    # x[parallel_flag] == s + 1, s + 2, ...
    x = jnp.cumsum(parallel_flag.astype(rows.dtype)) + start
    # Unflagged columns ahead of the first flag can have x == 0; clamp so the
    # modulo is well defined (those columns are discarded below anyway).
    half_period = jnp.maximum(x, 1)
    negative = rows[:, None] % (2 * half_period) >= half_period
    update = jnp.where(negative, -1, 1).astype(S.dtype)
    S = jnp.where(parallel_flag[None, :], update, S)
    return S, x[-1]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment