Skip to content

Instantly share code, notes, and snippets.

@anish-lakkapragada
Last active September 7, 2025 17:37
Show Gist options
  • Select an option

  • Save anish-lakkapragada/9f39da7f072e88d98cd56517e0193ce6 to your computer and use it in GitHub Desktop.

Select an option

Save anish-lakkapragada/9f39da7f072e88d98cd56517e0193ce6 to your computer and use it in GitHub Desktop.
Zach Furman's Singular Learning Theory Exercise #12 Code
# %%
import numpy as np
import matplotlib.pyplot as plt
import math
# Fixed seed so that every run draws the same samples.
rng = np.random.default_rng(seed=0)

# Samples in the experiment below are drawn from N(MU_0**3, 1).
MU_0 = 5
# beta multiplier passed to the free-energy routines.
BETA = 1
# lambda: coefficient of log(n) in the estimated free energy
# (presumably the "local learning coefficient" -- confirm against the exercise).
LLC = 0.5
# log(2*pi), shared by the Gaussian log-density formulas (tau == 2*pi).
LOG2PI = math.log(math.tau)
def log_pdf_sum_stats(N, S1, S2, mu):
    """
    Log-likelihood Σ_i log p(x_i | mu) for x_i ~ N(mu^3, 1), computed from
    sufficient statistics instead of the raw samples.

    Parameters
    ----------
    N  : number of observations.
    S1 : Σ x_i.
    S2 : Σ x_i^2.
    mu : parameter value; a scalar or a NumPy array (evaluated elementwise).

    Returns
    -------
    -0.5 * Σ (x_i - mu^3)^2 - 0.5 * N * log(2π), with the same shape as `mu`.
    """
    m3 = mu**3
    if N > 0:
        # Σ (x_i - m3)^2 = Σ (x_i - x̄)^2 + N (x̄ - m3)^2.
        # The expanded form S2 - 2*m3*S1 + N*m3^2 subtracts huge, nearly equal
        # numbers for every mu; here the only mu-dependent term is a plain
        # square, so rounding noise no longer varies across the mu grid.
        xbar = S1 / N
        ssr = (S2 - S1 * xbar) + N * (m3 - xbar) ** 2
    else:
        # No observations: fall back to the expanded form (no division by N).
        ssr = S2 - 2.0 * m3 * S1 + N * (m3**2)
    return -0.5 * ssr - 0.5 * N * LOG2PI
def logsumexp(a):
    """
    Numerically stable log(Σ exp(a_i)) over all elements of `a`.

    The maximum is subtracted before exponentiating so that large-magnitude
    inputs neither overflow nor underflow to log(0).

    Parameters
    ----------
    a : array-like of floats.

    Returns
    -------
    log Σ exp(a_i).  An empty input gives -inf (log of an empty sum); if the
    maximum is not finite (all -inf, any +inf, or NaN) that maximum is
    returned directly.
    """
    a = np.asarray(a, dtype=float)
    if a.size == 0:
        return -np.inf
    m = np.max(a)
    if not np.isfinite(m):
        # a - m would evaluate inf - inf = NaN; the answer is m itself.
        return m
    return m + np.log(np.sum(np.exp(a - m)))
def log_trapz_exp(f_vals, xs):
    """
    log ∫ exp(f(x)) dx by the trapezoid rule, evaluated entirely in log-space.

    Parameters
    ----------
    f_vals : 1-D array-like, f evaluated at each grid point.
    xs     : 1-D array-like of grid points, same length as `f_vals`,
             sorted in non-decreasing order.

    Returns
    -------
    log Σ_i dx_i * (exp(f_i) + exp(f_{i+1})) / 2.

    Raises
    ------
    ValueError if the inputs are not 1-D arrays of equal length, have fewer
    than two points, or `xs` is decreasing anywhere.
    """
    xs = np.asarray(xs, float)
    f_vals = np.asarray(f_vals, float)
    if xs.ndim != 1 or xs.shape != f_vals.shape:
        raise ValueError("xs and f_vals must be 1-D arrays of equal length")
    if xs.size < 2:
        raise ValueError("need at least two grid points for the trapezoid rule")
    dx = np.diff(xs)
    if np.any(dx < 0):
        raise ValueError("xs must be sorted in non-decreasing order")
    # A repeated grid point has dx == 0, i.e. log(dx) == -inf: a zero-width
    # panel that contributes nothing, so the divide warning is just noise.
    with np.errstate(divide="ignore"):
        log_dx = np.log(dx)
    # Per-panel log-area: log(dx) + log(e^{f_i} + e^{f_{i+1}}) - log 2.
    pair_log = log_dx + np.logaddexp(f_vals[:-1], f_vals[1:]) - math.log(2.0)
    return logsumexp(pair_log)
def numeric_free_energy_stable(samples, mus, beta=1.0):
    """
    Free energy -log ∫ exp(beta * Σ_i log p(x_i | mu)) d mu, integrated
    numerically over the grid `mus` in log-space.

    Parameters
    ----------
    samples : array-like of observations.
    mus     : 1-D array-like grid of mu values to integrate over.
    beta    : multiplier applied to the log-likelihood.

    Returns
    -------
    The negated log of the trapezoid-rule integral.
    """
    samples = np.asarray(samples, float)
    mus = np.asarray(mus, float)
    N = samples.size
    # Sufficient statistics: the likelihood depends on the data only via these.
    S1 = float(np.sum(samples))
    S2 = float(np.sum(samples**2))
    # log_pdf_sum_stats is plain elementwise arithmetic in mu, so the whole
    # grid is evaluated in one call instead of one Python call per point.
    f_vals = beta * np.asarray(log_pdf_sum_stats(N, S1, S2, mus), float)
    return -log_trapz_exp(f_vals, mus)
def empirical_energy(samples):
    """
    Mean negative log-density of `samples` under N(MU_0**3, 1).

    Parameters
    ----------
    samples : array-like of observations (lists are accepted, matching
              numeric_free_energy_stable, which also coerces its input).

    Returns
    -------
    -mean_i [ -0.5 * (x_i - MU_0**3)^2 - 0.5 * log(2π) ].
    """
    samples = np.asarray(samples, float)
    return -np.mean(-0.5 * (samples - MU_0**3) ** 2 - 0.5 * LOG2PI)
def estimate_free_energy(samples, llc, beta=1.0):
    """
    Estimated free energy n * beta * S_n + llc * log(n), where
    S_n = empirical_energy(samples) and n = len(samples).

    Parameters
    ----------
    samples : sized collection of observations.
    llc     : coefficient of the log(n) term.
    beta    : multiplier on the energy term.

    Raises
    ------
    ValueError if `samples` is empty (log(0) and the mean of nothing would
    otherwise silently yield NaN).
    """
    n = len(samples)
    if n == 0:
        raise ValueError("samples must be non-empty")
    return n * beta * empirical_energy(samples) + llc * np.log(n)
def make_mu_grid(samples, M=2001, k=8, max_halfwidth=0.5):
    """
    Adaptive μ-grid centered at μ̂ = cbrt(mean x).
    Posterior sd(μ) ≈ 1 / sqrt(9 N μ̂^4); grid spans ±k·sd(μ) (plus a small
    padding of one part in M), capped at ±max_halfwidth.

    Parameters
    ----------
    samples       : non-empty array-like of observations.
    M             : number of grid points (at least 2).
    k             : half-width of the grid in units of sd(μ).
    max_halfwidth : upper bound on the grid half-width.

    Returns
    -------
    (mus, mu_hat, sd_mu): the M-point grid, its center, and the sd estimate.

    Raises
    ------
    ValueError if `samples` is empty or M < 2.
    """
    samples = np.asarray(samples, float)
    if samples.size == 0:
        raise ValueError("samples must be non-empty")
    if M < 2:
        raise ValueError("M must be at least 2")
    xbar = float(np.mean(samples))
    mu_hat = float(np.cbrt(xbar))  # real cube root, also for negative means
    # Floor the denominator so that mu_hat ≈ 0 does not divide by zero; the
    # resulting huge sd is then clipped by max_halfwidth below.
    denom = max(9.0 * samples.size * (mu_hat**4), 1e-12)
    sd_mu = 1.0 / math.sqrt(denom)
    spread = k * sd_mu
    halfwidth = min(max_halfwidth, spread + spread / M)  # small padding
    mus = np.linspace(mu_hat - halfwidth, mu_hat + halfwidth, M)
    return mus, mu_hat, sd_mu
# ---- experiment & plot ----
# Sample sizes spaced evenly on a log scale, cast to int: 1e2 ... 1e6
Ns = np.logspace(2, 6, num=20, dtype=int)
numeric_energies, estimated_energies = [], []
for N in Ns:
    # Draw N observations from the true distribution N(MU_0**3, 1).
    samples = rng.normal(loc=MU_0**3, scale=1.0, size=N)
    mus, mu_hat, sd_mu = make_mu_grid(samples, M=2001, k=8, max_halfwidth=0.5)
    # Numerically integrated free energy vs. the closed-form estimate.
    Fn_num = numeric_free_energy_stable(samples, mus, beta=BETA)
    Fn_est = estimate_free_energy(samples, LLC, beta=BETA)
    numeric_energies.append(Fn_num)
    estimated_energies.append(Fn_est)
    print(f"N={N}, mu_hat≈{mu_hat:.6f}, sd_mu≈{sd_mu:.2e}, "
          f"Numeric F_n={Fn_num:.6f}, Estimated F_n={Fn_est:.6f}")

# Compare both curves on log-log axes.
plt.figure(figsize=(10, 6))
plt.plot(Ns, numeric_energies, marker='o', label=r'Numeric $F_n$')
plt.plot(Ns, estimated_energies, marker='x',
         label=r'Estimated $F_n \approx n\beta S_n + \lambda\log n$')
plt.xscale('log')
plt.yscale('log')
plt.xlabel(r'$N$')
plt.ylabel('Free Energy')
# beta is interpolated from BETA (like MU_0 and LLC) so the title cannot go
# stale if the constant is changed.
plt.title(rf'Numerically Computed vs. Estimated $F_n$ for $\beta = {BETA}, \mu_0 = {MU_0}, \lambda = {round(LLC, 3)}$')
plt.legend()
plt.grid(True)
plt.show()
# %%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment