Created
January 2, 2026 01:31
-
-
Save howsiyu/493d45ba513a61dc609ca67c41f7e29e to your computer and use it in GitHub Desktop.
Implement scipy.sparse.linalg.onenormest in idiomatic JAX
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from functools import partial | |
| from typing import Any, NamedTuple, Protocol | |
| import jax | |
| import jax.numpy as jnp | |
| from jax import lax | |
| from jax import Array | |
class LinearOperator(Protocol):
    """Structural interface for the matrix-like operand ``A`` of the estimator.

    Any object that exposes the attributes and methods below can be used;
    inheriting from this class is not required.
    """

    # Element type of the operator's entries.
    dtype: Any
    # Two-dimensional shape; the routines in this file unpack it as ``(m, n)``.
    shape: tuple[int, int]

    def matmat(self, X: Array) -> Array:
        """Return the product ``A @ X``."""
        ...

    def rmatmat(self, X: Array) -> Array:
        """Return ``A.H @ X``, where ``A.H`` is the conjugate transpose of ``A``."""
        ...

    def cols(self, indices: Array) -> Array:
        """Return the selected columns ``A[:, indices]``.

        Picking columns directly can often be done faster than going
        through ``matmat``.
        """
        ...
class OnenormestLoopState(NamedTuple):
    """Values carried from one iteration of the estimation loop to the next."""

    # Set once a stopping criterion has fired; the loop runs while it is False.
    break_flag: Array
    # Iteration counter.
    k: Array
    # How many parallel columns of S have been detected so far.
    # Instead of drawing random replacements, a column of S that is parallel
    # to a column of S_old (or to another column of S) is overwritten with a
    # deterministic pattern whose sign flips every i rows, for i = 1, 2, ...
    # That should be good enough: replacing parallel columns has little effect
    # on most matrices anyway (and is skipped entirely for complex matrices).
    s: Array
    # Current block of sign vectors, one per column.
    S: Array
    # Boolean mask marking which unit vectors e_j have already been used.
    ind_hist: Array
    # Indices of the t largest columns encountered so far.
    best_t_indices: Array
    # 1-norms of the columns listed in best_t_indices.
    best_t_estimates: Array
@partial(jax.jit, static_argnums=[2, 4, 5])
def onenormest(
    key,
    A: LinearOperator,
    t: int,
    itmax=5,
    compute_v: bool = False,
    compute_w: bool = False,
):
    """Estimate the 1-norm of the linear operator ``A``.

    When ``4 * t >= n`` (``n`` being the number of columns) every column is
    materialised through ``A.cols`` and the exact norm, i.e. the largest
    column sum of absolute values, is returned. Otherwise a random sign
    matrix is drawn from ``key`` and the block iteration in
    ``onenormest_core`` produces the estimate.

    Args:
        key: JAX PRNG key; only consumed on the iterative path.
        A: Operator exposing ``shape``, ``dtype``, ``matmat``, ``rmatmat``
            and ``cols``.
        t: Number of columns handled per iteration (static under ``jit``).
        itmax: Iteration limit forwarded to ``onenormest_core``, which clips
            it to the range ``[2, n // t - 1]``.
        compute_v: If true, also return the unit vector ``e_j`` selecting the
            column that attains the estimate (static under ``jit``).
        compute_w: If true, also return that column, ``A.cols(j)`` (static
            under ``jit``).

    Returns:
        ``est`` alone when neither certificate is requested; otherwise a
        tuple that starts with ``est`` and is followed by ``v`` and/or ``w``
        in that order.
    """
    m, n = A.shape
    if 4 * t < n:
        if jnp.issubdtype(A.dtype, jnp.complexfloating):
            # Complex case: random points on the unit circle.
            real_dtype = jnp.zeros((), dtype=A.dtype).real.dtype
            phases = jax.random.uniform(key, shape=(m, t), dtype=real_dtype)
            random_signs = jnp.exp(phases * (jnp.pi * 2j))
        else:
            # Real case: independent entries equal to +1 or -1.
            coin_flips = jax.random.bernoulli(key, shape=(m, t))
            random_signs = coin_flips.astype(A.dtype) * -2 + 1
        # According to the author's experiments, this start performs about as
        # well as the one suggested in the paper (ones and rand{-1, 1} columns)
        # on random matrices, but vastly outperforms it on the pathological
        # triangular matrix example given in the paper.
        start = sign_round_up(A.rmatmat(random_signs))
        outcome = onenormest_core(A, start, jnp.array(itmax))
        j = outcome.best_t_indices[0]
        est = outcome.best_t_estimates[0]
    else:
        # The operator is small relative to t, so computing the exact norm
        # from all of its columns is the easier route.
        column_sums = lax.abs(A.cols(jnp.arange(n))).sum(axis=0)
        j = jnp.argmax(column_sums)
        est = column_sums[j]

    if not (compute_v or compute_w):
        return est
    # Return the estimate together with the requested certificates.
    result = (est,)
    if compute_v:
        result += (jnp.zeros(n, A.dtype).at[j].set(1),)
    if compute_w:
        result += (A.cols(j),)
    return result
def onenormest_core(
    A: LinearOperator, starting_matrix: Array, itmax: Array
) -> OnenormestLoopState:
    """
    A deterministic variant of Algorithm 2.4 in
    Nicholas J. Higham and Francoise Tisseur (2000),
    "A Block Algorithm for Matrix 1-Norm Estimation,
    with an Application to 1-Norm Pseudospectra."
    that given a starting_matrix, returns t column indices
    with their norms that estimate the largest t columns.

    Args:
        A: operator of shape (m, n).
        starting_matrix: array of shape (n, t); its column count fixes t.
        itmax: requested iteration limit; clipped to [2, n // t - 1].

    Returns:
        The final loop state. Entry 0 of ``best_t_indices`` /
        ``best_t_estimates`` is what ``onenormest`` reads as the estimate.

    Raises:
        ValueError: if the row count of starting_matrix differs from n,
            if t < 1, or if 4 * t >= n.
    """
    m, n = A.shape
    n2, t = starting_matrix.shape
    if n != n2:
        raise ValueError("A and starting_matrix must have compatible shapes!")
    if t < 1:
        raise ValueError("at least one column is required")
    if 4 * t >= n:
        raise ValueError("4 * t should be smaller than the order of A")
    # If (itmax + 1) * t > n we could have s >= n which makes us loop
    # indefinitely when resampling S's columns. In practice this algorithm
    # only makes sense if n >> itmax * t so this shouldn't happen.
    # The clip also gives itmax * t <= n - t; each iteration marks t indices
    # in ind_hist, so t unused indices remain available on every pass.
    itmax = jnp.clip(itmax, min=2, max=n // t - 1)
    Y = A.matmat(starting_matrix)
    init_state = OnenormestLoopState(
        break_flag=jnp.array(False),
        k=jnp.array(1),
        s=jnp.array(0),
        S=sign_round_up(Y),
        ind_hist=jnp.zeros(n, dtype=jnp.bool),
        best_t_indices=jnp.zeros(t, dtype=int),
        # Negative sentinel: any column 1-norm (>= 0) ranks above it in top_k.
        best_t_estimates=jnp.full(t, -jnp.abs(jnp.array(1, A.dtype))),
    )

    def cond_fun(state: OnenormestLoopState):
        # Keep iterating until some stopping criterion has set break_flag.
        return ~state.break_flag

    def body_fun(state: OnenormestLoopState) -> OnenormestLoopState:
        S = state.S
        s = state.s
        k = state.k
        # Parallel columns are only replaced for non-complex operators.
        if not jnp.issubdtype(A.dtype, jnp.complexfloating):
            S, s = resample_S1(S, s)
        h = jnp.abs(A.rmatmat(S)).max(axis=1)  # (3) in the paper
        # Candidate column indices, ordered by decreasing h.
        ind_tmp = jnp.argsort(h, descending=True)
        # Positions (within the sorted order) of the first t indices that
        # have not been used yet.
        (best_t_indices_not_in_ind_hist,) = jnp.nonzero(
            ~state.ind_hist[ind_tmp], size=t
        )
        break_flag = (k >= 2) & (
            (h[ind_tmp[0]] == h[state.best_t_indices[0]])  # break condition in (4)
            # First unused position >= t means the top t were all used already.
            | (best_t_indices_not_in_ind_hist[0] >= t)  # break condition in (5)
        )
        ind = ind_tmp[best_t_indices_not_in_ind_hist]
        ind_hist = state.ind_hist.at[ind].set(True)
        # Evaluate the chosen columns directly and take their 1-norms.
        Y = A.cols(ind)
        Y_abs = lax.abs(Y).sum(axis=0)
        est_old = state.best_t_estimates[0]
        # Stop when none of the new columns beats the best estimate so far.
        break_flag = break_flag | (Y_abs.max() <= est_old)  # break condition in (1)
        # Merge old and new candidates and keep the t largest. The state is
        # updated even on the iteration that sets break_flag.
        best_ests = jnp.append(state.best_t_estimates, Y_abs)
        best_t_estimates, best_ests_sorted_indices = lax.top_k(best_ests, t)
        best_t_indices = jnp.append(state.best_t_indices, ind)[best_ests_sorted_indices]
        S_new = sign_round_up(Y)
        if not jnp.issubdtype(A.dtype, jnp.complexfloating):
            # A column of S_new counts as parallel to a column of S when the
            # absolute value of their inner product equals m.
            parallel_to_S = jnp.any(jnp.abs(S.T @ S_new) == m, axis=0)
            break_flag = break_flag | jnp.all(parallel_to_S)  # break condition in (2)
            S_new, s = resample_S2(S_new, s, S, parallel_to_S)
        return OnenormestLoopState(
            # Also stop once the (clipped) iteration limit is reached.
            break_flag=break_flag | (k >= itmax),
            k=k + 1,
            s=s,
            S=S_new,
            ind_hist=ind_hist,
            best_t_indices=best_t_indices,
            best_t_estimates=best_t_estimates,
        )

    return lax.while_loop(cond_fun, body_fun, init_state)
def sign_round_up(Y: Array):
    """Return the elementwise sign of ``Y``, mapping zero to +1.

    For real input the result is -1 where the sign bit is set and +1
    elsewhere. For complex input each entry is divided by its modulus, with
    zero entries replaced by 1 beforehand.
    """
    if not jnp.isrealobj(Y):
        # Swap zeros for ones first so the division below is well defined.
        safe = lax.select(Y != 0, Y, jnp.ones_like(Y))
        return safe / jnp.abs(safe)
    negative = jnp.signbit(Y)
    # True -> -1, False -> +1.
    return negative.astype(Y.dtype) * -2 + 1
class ResampleSLoopState(NamedTuple):
    """Carry for the loops that replace parallel columns of S."""

    # Sign matrix whose flagged columns get overwritten.
    S: Array
    # Running count of replaced columns; selects the next fill pattern.
    s: Array
    # Per-column flag marking the columns that still need replacing.
    parallel_flag: Array
def resample_S1(S, s):
    """Replace columns of S until no column is parallel to an earlier one.

    Args:
        S: sign matrix, one vector per column.
        s: number of columns replaced so far.

    Returns:
        The updated ``(S, s)`` pair.
    """
    n_rows = S.shape[0]

    def flag_duplicates(mat):
        # Two columns count as parallel when the absolute value of their
        # inner product equals the row count (this presumes +/-1 entries).
        # The strict upper triangle flags only the later column of a pair.
        hits = jnp.abs(mat.T @ mat) == n_rows
        return jnp.any(jnp.triu(hits, k=1), axis=0)

    def step(state: ResampleSLoopState):
        replaced, counter = resample_S_core(state)
        return ResampleSLoopState(replaced, counter, flag_duplicates(replaced))

    start = ResampleSLoopState(S, s, flag_duplicates(S))
    finished = jax.lax.while_loop(resample_S_cond_fun, step, start)
    return finished.S, finished.s
def resample_S2(S, s, S_old, parallel_to_S_old):
    """Replace columns of S until none is parallel to any column of S_old.

    Args:
        S: new sign matrix, one vector per column.
        s: number of columns replaced so far.
        S_old: previous sign matrix that the columns of S are checked against.
        parallel_to_S_old: per-column flags for S, computed by the caller,
            marking the columns to replace first.

    Returns:
        The updated ``(S, s)`` pair.
    """
    n_rows = S.shape[0]

    def step(state: ResampleSLoopState):
        candidate, counter = resample_S_core(state)
        # Re-check the replaced matrix against S_old only.
        clashes = jnp.abs(S_old.T @ candidate) == n_rows
        return ResampleSLoopState(candidate, counter, jnp.any(clashes, axis=0))

    start = ResampleSLoopState(S, s, parallel_to_S_old)
    finished = jax.lax.while_loop(resample_S_cond_fun, step, start)
    return finished.S, finished.s
def resample_S_cond_fun(state: ResampleSLoopState):
    """Return whether any column is still flagged as parallel."""
    flags = state.parallel_flag
    return jnp.any(flags)
def resample_S_core(state: ResampleSLoopState):
    """Overwrite every flagged column of S with a deterministic sign pattern.

    Flagged columns receive consecutive pattern numbers s + 1, s + 2, ...;
    pattern number p is +1 for p rows, then -1 for p rows, and so on.

    Returns:
        ``(S, s)`` with the flagged columns replaced and the counter advanced
        by the number of flagged columns. The counter comes back in the dtype
        of ``S`` because the cumulative sum is taken in that dtype.
    """
    S, s, parallel_flag = state
    n_rows = S.shape[0]
    # Running count of flagged columns, offset by s: at flagged positions
    # this yields s + 1, s + 2, ...
    period = jnp.cumsum(parallel_flag.astype(S.dtype).at[0].add(s))
    row_ids = jnp.arange(n_rows)[:, None]
    in_negative_half = row_ids % (2 * period) >= period
    replacement = in_negative_half.astype(S.dtype) * -2 + 1
    # Broadcast the per-column flags over all rows; unflagged columns keep
    # their current values.
    mask = jnp.full(S.shape, parallel_flag)
    return lax.select(mask, replacement, S), period[-1]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment