Last active
September 7, 2025 17:37
-
-
Save anish-lakkapragada/9f39da7f072e88d98cd56517e0193ce6 to your computer and use it in GitHub Desktop.
Zach Furman's Singular Learning Theory Exercise #12 Code
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %%
import math

import matplotlib.pyplot as plt
import numpy as np
# Reproducibility: a single seeded generator is shared by the whole script.
rng = np.random.default_rng(0)

# Data-generating parameter: samples are drawn from N(MU_0**3, 1).
MU_0 = 5
# β, the multiplier applied to the log-likelihood in the free energy.
BETA = 1
# λ, the coefficient of log(n) in the asymptotic free-energy estimate.
LLC = 1/2
# log(2π), the normalising constant of the unit-variance Gaussian log-density.
LOG2PI = math.log(2 * math.pi)
def log_pdf_sum_stats(N, S1, S2, mu):
    """Return Σ_i log p(x_i | mu) for the model x ~ N(mu**3, 1).

    The sum is evaluated from sufficient statistics, so the data themselves
    are never revisited.

    Args:
        N: Number of observations.
        S1: Sum of the observations, Σ x_i.
        S2: Sum of the squared observations, Σ x_i**2.
        mu: Model parameter; a scalar or any array-like, evaluated
            elementwise.

    Returns:
        The total log-likelihood, with the same shape as ``mu``.
    """
    # Coerce so that plain lists broadcast the same way ndarrays do.
    m3 = np.asarray(mu, dtype=float) ** 3
    # Σ (x_i - m3)^2 expanded in terms of S1 and S2.
    ssr = S2 - 2.0 * m3 * S1 + N * (m3**2)
    return -0.5 * ssr - 0.5 * N * LOG2PI
def logsumexp(a):
    """Return log(Σ exp(a_i)) computed without overflow.

    Args:
        a: Array-like of log-values; it is flattened by the reduction.

    Returns:
        The log of the summed exponentials. An empty input or an input that
        is entirely ``-inf`` yields ``-inf`` (the log of an empty/zero sum);
        an input containing ``+inf`` yields ``+inf``.
    """
    a = np.asarray(a, dtype=float)
    if a.size == 0:
        return -np.inf
    m = np.max(a)
    if not np.isfinite(m):
        # Subtracting a non-finite maximum would give inf - inf = NaN.
        # The maximum itself is already the correct answer here:
        # -inf (all terms vanish), +inf (a term diverges), or NaN.
        return m
    # Shift by the maximum so the largest exponent is exactly exp(0) = 1.
    return m + np.log(np.sum(np.exp(a - m)))
def log_trapz_exp(f_vals, xs):
    """Return log ∫ exp(f(x)) dx using the trapezoid rule in log-space.

    Each panel contributes dx * (exp(f_left) + exp(f_right)) / 2; the panels
    are combined with ``logsumexp`` so that exp() is never evaluated on the
    raw values.

    Args:
        f_vals: Values f(x_i) at the grid points.
        xs: 1-D grid of at least two points in non-decreasing order, the
            same shape as ``f_vals``.

    Returns:
        The log of the trapezoid approximation to the integral.

    Raises:
        ValueError: If the inputs are not 1-D arrays of equal shape, if there
            are fewer than two grid points, or if the grid decreases
            anywhere.
    """
    xs = np.asarray(xs, float)
    f_vals = np.asarray(f_vals, float)
    if xs.ndim != 1 or xs.shape != f_vals.shape:
        raise ValueError("f_vals and xs must be 1-D arrays of the same shape")
    if xs.size < 2:
        raise ValueError("at least two grid points are required")
    dx = xs[1:] - xs[:-1]
    if np.any(dx < 0):
        # log(dx) would silently turn the whole result into NaN.
        raise ValueError("xs must be in non-decreasing order")
    # A repeated grid point gives log(0) = -inf, i.e. a zero-width panel;
    # that is the intended contribution, so silence the divide warning.
    with np.errstate(divide="ignore"):
        log_dx = np.log(dx)
    pair_log = log_dx + np.logaddexp(f_vals[:-1], f_vals[1:]) - math.log(2.0)
    return logsumexp(pair_log)
def numeric_free_energy_stable(samples, mus, beta=1.0):
    """Return -log ∫ exp(beta * Σ_i log p(x_i | mu)) dmu over the grid.

    The integrand is the tempered likelihood alone (no prior density is
    multiplied in), integrated numerically over the span of ``mus``.

    Args:
        samples: Observations, array-like.
        mus: Grid of parameter values to integrate over.
        beta: Multiplier applied to the summed log-likelihood.

    Returns:
        The negative log of the numerically integrated tempered likelihood.
    """
    samples = np.asarray(samples, float)
    mus = np.asarray(mus, float)
    # Sufficient statistics are computed once and reused for every grid point.
    N = samples.size
    S1 = float(np.sum(samples))
    S2 = float(np.sum(samples**2))
    # log_pdf_sum_stats broadcasts over mu, so the whole grid is evaluated in
    # a single vectorized call instead of one Python call per point.
    f_vals = beta * log_pdf_sum_stats(N, S1, S2, mus)
    return -log_trapz_exp(f_vals, mus)
def empirical_energy(samples, mu_0=None):
    """Return the mean negative log-density of samples under N(mu_0**3, 1).

    Args:
        samples: Observations, array-like.
        mu_0: Parameter at which the density is evaluated. Defaults to the
            module-level ``MU_0``.

    Returns:
        The average of -log p(x_i | mu_0) over all samples.
    """
    if mu_0 is None:
        mu_0 = MU_0
    # Coerce so that plain Python sequences are accepted as well as ndarrays.
    samples = np.asarray(samples, float)
    log_p = -0.5 * (samples - mu_0**3) ** 2 - 0.5 * LOG2PI
    return -np.mean(log_p)
def estimate_free_energy(samples, llc, beta=1.0):
    """Return the asymptotic estimate n * beta * S_n + llc * log(n).

    Here n is the number of samples and S_n is ``empirical_energy(samples)``.

    Args:
        samples: Observations; must be non-empty.
        llc: Coefficient λ multiplying log(n).
        beta: Multiplier applied to the energy term.

    Returns:
        The estimated free energy.

    Raises:
        ValueError: If ``samples`` is empty (log(0) and the mean of an empty
            array are both undefined).
    """
    n = len(samples)
    if n == 0:
        raise ValueError("samples must be non-empty")
    return n * beta * empirical_energy(samples) + llc * np.log(n)
def make_mu_grid(samples, M=2001, k=8, max_halfwidth=0.5):
    """Build an adaptive μ-grid centered at μ̂ = cbrt(mean x).

    The posterior spread is approximated by sd(μ) ≈ 1 / sqrt(9 N μ̂^4) and
    the grid spans roughly ±k·sd(μ) around μ̂, capped at ``max_halfwidth``.

    Args:
        samples: Observations; must be non-empty.
        M: Number of grid points; must be at least 2.
        k: Half-width of the grid in units of sd(μ).
        max_halfwidth: Upper bound on the grid half-width.

    Returns:
        A tuple ``(mus, mu_hat, sd_mu)``: the evenly spaced grid of ``M``
        points, its center, and the approximate posterior standard deviation.

    Raises:
        ValueError: If ``samples`` is empty or ``M`` is less than 2.
    """
    samples = np.asarray(samples, float)
    if samples.size == 0:
        raise ValueError("samples must be non-empty")
    if M < 2:
        raise ValueError("M must be at least 2 to form an integration grid")

    xbar = float(np.mean(samples))
    mu_hat = float(np.cbrt(xbar))  # real cube root, defined for negative xbar
    # The floor keeps sd_mu finite when mu_hat is (numerically) zero.
    denom = max(9.0 * samples.size * (mu_hat**4), 1e-12)
    sd_mu = 1.0 / math.sqrt(denom)
    # k·sd plus a small padding of one M-th of that width.
    halfwidth = min(max_halfwidth, k * sd_mu + (k * sd_mu) / M)
    mus = np.linspace(mu_hat - halfwidth, mu_hat + halfwidth, M)
    return mus, mu_hat, sd_mu
# ---- experiment & plot ----
# Sample sizes: 20 log-spaced values from 1e2 to 1e6, truncated to integers.
Ns = np.logspace(2, 6, num=20, dtype=int)
numeric_energies, estimated_energies = [], []

for N in Ns:
    # Fresh draw from the data-generating distribution N(MU_0**3, 1).
    samples = rng.normal(loc=MU_0**3, scale=1.0, size=N)
    mus, mu_hat, sd_mu = make_mu_grid(samples, M=2001, k=8, max_halfwidth=0.5)

    # Numerically integrated free energy vs. its asymptotic estimate.
    Fn_num = numeric_free_energy_stable(samples, mus, beta=BETA)
    Fn_est = estimate_free_energy(samples, LLC, beta=BETA)
    numeric_energies.append(Fn_num)
    estimated_energies.append(Fn_est)

    print(f"N={N}, mu_hat≈{mu_hat:.6f}, sd_mu≈{sd_mu:.2e}, "
          f"Numeric F_n={Fn_num:.6f}, Estimated F_n={Fn_est:.6f}")

plt.figure(figsize=(10, 6))
plt.plot(Ns, numeric_energies, marker='o', label=r'Numeric $F_n$')
plt.plot(Ns, estimated_energies, marker='x',
         label=r'Estimated $F_n \approx n\beta S_n + \lambda\log n$')
plt.xscale('log')
plt.yscale('log')
plt.xlabel(r'$N$')
plt.ylabel('Free Energy')
# Interpolate BETA rather than hard-coding its value so the title cannot
# drift out of sync with the constant actually used above.
plt.title(rf'Numerically Computed vs. Estimated $F_n$ for $\beta = {BETA}, \mu_0 = {MU_0}, \lambda = {round(LLC, 3)}$')
plt.legend()
plt.grid(True)
plt.show()
# %%
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment