Skip to content

Instantly share code, notes, and snippets.

@danielvarga
Created May 24, 2025 01:28
Show Gist options
  • Select an option

  • Save danielvarga/4cf118c2c76ca09a2e7f9e781eba6b9f to your computer and use it in GitHub Desktop.

Select an option

Save danielvarga/4cf118c2c76ca09a2e7f9e781eba6b9f to your computer and use it in GitHub Desktop.
Endre Csóka's optimization task over symmetric monotone functions (toy variant).
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # registers the 3-D projection
from matplotlib import cm
# ───────────────────────── optimiser ──────────────────────────────────────
def solve_monotone_symmetric(
n: int,
lr: float = 5e-3,
steps: int = 3000,
λ_mono: float = 1e3,
λ_sym: float = 1e3,
λ_corner: float = 1e3,
# --- new level-set hyper-parameters ---
λ_level: float = 1e3,
num_levels: int = 100,
tau: float = 0.001,
seed: int = 0):
"""
Minimise F[n//2, n//2] over n×n matrices F that are
• coordinate-wise monotone (F[i,j] ≤ F[k,l] if i≤k, j≤l)
• symmetric (F = Fᵀ)
• fixed at the corners (0, 1, 1, 1 in TL, TR, BL, BR)
• soft level-set coverage For each c ∈ linspace(0,1,num_levels),
approx_mean(1{F<c}) ≤ c via sigmoid((c-F)/τ)
All constraints are enforced with quadratic penalties; penalties on
monotone and symmetry terms are normalised by n² for scale stability.
"""
torch.manual_seed(seed)
i = torch.arange(n).unsqueeze(1) # shape (n, 1)
j = torch.arange(n).unsqueeze(0) # shape (1, n)
F_values = (i + j) / (2 * n) # shape (n, n)
F_values += torch.rand(n, n) / 100 # perturbation
F = torch.nn.Parameter(F_values)
opt = torch.optim.Adam([F], lr=lr)
mid = n // 2
# pre-compute the level values on CPU; we move them to F.device in loop
levels = torch.linspace(0.0, 1.0, num_levels)
for iteration in range(steps):
opt.zero_grad()
# ── objective & core penalties ───────────────────────────────────
centre = F[mid, mid]
up = torch.relu(F[:-1, :] - F[1:, :])
right = torch.relu(F[:, :-1] - F[:, 1:])
mono_pen = (up.pow(2).sum() + right.pow(2).sum()) / n ** 2
sym_pen = (F - F.t()).pow(2).sum() / n ** 2
corner_pen = F[0, 0]**2 \
+ (F[0, -1] - 1)**2 \
+ (F[-1, 0] - 1)**2 \
+ (F[-1, -1] - 1)**2
corner_pen *= 0.0
# ── level-set coverage penalty (new) ────────────────────────────
lvl_pen = F.new_tensor(0.)
for c in levels.to(F.device):
coverage = torch.sigmoid((c - F) / tau).mean() # ≈ P(F < c)
'''
if iteration in (1000, 4000, 8000):
true_coverage = (F.detach().numpy() < c.numpy()).mean()
print(c, coverage.detach().numpy(), true_coverage)
'''
lvl_pen += ((coverage - c) / (c + 1)) ** 2
loss = λ_mono * mono_pen + λ_sym * sym_pen + λ_corner * corner_pen + λ_level * lvl_pen
loss.backward()
opt.step()
if iteration % 1000 == 0:
print(iteration,
"loss =", loss.detach().numpy(),
"centre =", centre.detach().numpy(),
"sym_pen =", (λ_sym * sym_pen).detach().numpy(),
"mono_pen =", (λ_mono * mono_pen).detach().numpy(),
"corner_pen =", (λ_corner * corner_pen).detach().numpy(),
"lvl_pen =", (λ_level * lvl_pen).detach().numpy())
show(F.detach(), title=f"{iteration =}")
return F.detach()
# ───────────────────────── visualisation helper ──────────────────────────
def show(F, title: str | None = None, elev: int = 30, azim: int = -60,
         cmap: str = "viridis"):
    """Display an interactive 3-D surface plot of matrix ``F``, coloured by height.

    Parameters
    ----------
    F : torch.Tensor or array-like
        2-D matrix to plot; entry ``F[i, j]`` is drawn at grid point ``(i, j)``.
        It need not be square. Tensors that still require grad or live off
        the CPU are detached and copied to the CPU first.
    title : str, optional
        Plot title. When empty or ``None``, a default mentioning the row
        count is used.
    elev, azim : int
        Camera elevation and azimuth passed to ``ax.view_init``.
    cmap : str
        Matplotlib colormap name used to colour the surface.

    Raises
    ------
    ValueError
        If ``F`` is not two-dimensional.
    """
    Z = torch.as_tensor(F).detach().cpu().numpy()
    if Z.ndim != 2:
        raise ValueError(f"show() expects a 2-D matrix, got shape {Z.shape}")
    rows, cols = Z.shape
    # "ij" indexing: X varies along axis 0 (i) and Y along axis 1 (j), like Z.
    X, Y = torch.meshgrid(torch.arange(rows), torch.arange(cols), indexing="ij")

    fig = plt.figure(figsize=(6, 5))
    ax = fig.add_subplot(111, projection="3d")
    # Plot surface coloured according to height (Z)
    surf = ax.plot_surface(X.numpy(), Y.numpy(), Z,
                           cmap=cmap,
                           linewidth=0,
                           antialiased=True,
                           rstride=1, cstride=1)
    # Add a colour bar to indicate height scale
    fig.colorbar(surf, ax=ax, shrink=0.6, aspect=10, label="height")
    ax.view_init(elev=elev, azim=azim)
    ax.set_xlabel("i index")
    ax.set_ylabel("j index")
    ax.set_zlabel("F[i,j]")
    ax.set_title(title or f"Monotone symmetric matrix (n = {rows})")
    plt.tight_layout()
    plt.show()
# ───────────────────────── demo block ────────────────────────────────────
if __name__ == "__main__":
    # Demo: optimise a 100x100 grid with the default hyper-parameters,
    # then plot the resulting surface.
    grid_size = 100
    solution = solve_monotone_symmetric(grid_size)
    show(solution, title="Surface with level-set penalty and colour by height")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment