@pgtwitter
Created February 1, 2026 01:42
Comparing the TensorLens method, a coarse first-order (linear) approximation, and a simple cubic approximation of the GELU function
# %%
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
# ────────────────────────────────────────────────
# Data preparation
# ────────────────────────────────────────────────
x_np = np.linspace(-4, 4, 401)  # fine sampling grid over [-4, 4]
x = torch.from_numpy(x_np).float().unsqueeze(0)  # shape (1, 401)
# Exact GELU (PyTorch's default is the exact erf-based formula)
true_gelu = F.gelu(x).squeeze(0).numpy()
# 1. TensorLens method (input-dependent diagonal tensor Ψ, exact reconstruction)
mask = x.abs() > 1e-6
psi_diag = torch.where(
    mask,
    F.gelu(x) / x,  # Ψ(x) = GELU(x) / x where x is safely away from 0
    torch.tensor(0.5),  # limit of GELU(x)/x as x → 0 is Φ(0) = 0.5
)
reconstructed_exact = (psi_diag * x).squeeze(0).numpy()
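# Sanity check (not in the original gist): since exact GELU(x) = x * Φ(x),
# the diagonal of Ψ is just the standard normal CDF. Assuming PyTorch >= 1.10,
# torch.special.ndtr computes Φ directly, so the two should agree to float precision.
phi = torch.special.ndtr(x)
assert torch.allclose(psi_diag, phi, atol=1e-6)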
# 2. Coarse first-order approximation (the usual "linear approximation" level)
approx_linear = 0.5 * x.squeeze(0).numpy()  # first-order Taylor at 0: GELU'(0) = 0.5
# 3. Simple cubic approximation (the widely cited tanh-based formula)
def gelu_cubic_approx(z):
    # Simple cubic approximation (close to a low-order Hermite expansion)
    return 0.5 * z * (1 + torch.tanh(np.sqrt(2 / np.pi) * (z + 0.044715 * z**3)))
approx_cubic = gelu_cubic_approx(x).squeeze(0).numpy()
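# Cross-check (assumption: PyTorch >= 1.12, which exposes approximate='tanh'):
# the hand-written cubic formula is the same tanh-based approximation that
# F.gelu implements, so the two should agree to float precision.
builtin_tanh = F.gelu(x, approximate='tanh').squeeze(0).numpy()
print(f"cubic vs built-in tanh GELU: {np.abs(approx_cubic - builtin_tanh).max():.2e}")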
# ────────────────────────────────────────────────
# Error computation (printed for reference)
# ────────────────────────────────────────────────
print(f"Max absolute error:")
print(f" Ψ (exact) : {np.abs(true_gelu - reconstructed_exact).max():.2e}")
print(f" Linear : {np.abs(true_gelu - approx_linear).max():.2e}")
print(f" Cubic approx: {np.abs(true_gelu - approx_cubic).max():.2e}")
# ────────────────────────────────────────────────
# Plot
# ────────────────────────────────────────────────
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8.5), sharex=True)
# Top panel: function values
ax1.plot(x_np, true_gelu, 'k-', lw=2.5, label='True GELU')
ax1.plot(x_np, reconstructed_exact, '--', color='limegreen', lw=2.2,
         label='TensorLens Ψ (exact)')
ax1.plot(x_np, approx_linear, '-.', color='orange', lw=1.8,
         label='Linear approx (0.5x)')
ax1.plot(x_np, approx_cubic, ':', color='purple', lw=2.1,
         label='Cubic-like approx')
ax1.set_title('GELU and its approximations (x ∈ [-4, 4])', fontsize=13)
ax1.set_ylabel('Value', fontsize=11)
ax1.legend(fontsize=10, loc='upper left')
ax1.grid(True, alpha=0.3)
ax1.set_ylim(-0.1, 4.1)
# Bottom panel: absolute error (log scale; +1e-10 guards against log(0))
ax2.plot(x_np, np.abs(true_gelu - reconstructed_exact) + 1e-10,
         '--', color='limegreen', lw=2.2, label='TensorLens Ψ error')
ax2.plot(x_np, np.abs(true_gelu - approx_linear) + 1e-10,
         '-.', color='orange', lw=1.8, label='Linear error')
ax2.plot(x_np, np.abs(true_gelu - approx_cubic) + 1e-10,
         ':', color='purple', lw=2.1, label='Cubic approx error')
ax2.set_yscale('log')
ax2.set_xlabel('x', fontsize=11)
ax2.set_ylabel('Absolute error (log scale)', fontsize=11)
ax2.set_title('Approximation errors — log scale', fontsize=13)
ax2.legend(fontsize=10, loc='upper right')
ax2.grid(True, alpha=0.3)
plt.tight_layout()
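# Optional (not in the original gist; the filename is arbitrary): save the
# figure to disk before showing it, so the comparison can be attached elsewhere.
fig.savefig('gelu_comparison.png', dpi=150, bbox_inches='tight')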
plt.show()