-
-
Save pgtwitter/08b16f123bd17933acafc81f79821aa9 to your computer and use it in GitHub Desktop.
GELU関数のTensorLens方式,粗い1次近似,簡易3次近似の比較
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # %% | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
# ────────────────────────────────────────────────
# Data preparation
# ────────────────────────────────────────────────
x_np = np.linspace(-4, 4, 401)                    # fine sampling over [-4, 4]; includes x = 0 exactly
x = torch.from_numpy(x_np).float().unsqueeze(0)   # shape (1, 401)

# Reference: true GELU values
true_gelu = F.gelu(x).squeeze(0).numpy()

# 1. TensorLens-style exact reconstruction with an input-dependent diagonal Ψ:
#    GELU(x) = Ψ(x) · x where Ψ(x) = GELU(x)/x = Φ(x) (the standard normal CDF).
mask = x.abs() > 1e-6
psi_diag = torch.where(
    mask,
    F.gelu(x) / x,
    # lim_{x→0} GELU(x)/x = Φ(0) = 0.5 — NOT φ(0) ≈ 0.3989 (the pdf value).
    # This entry is multiplied by x = 0, so the reconstructed curve is
    # unaffected either way, but 0.5 is the mathematically correct diagonal.
    torch.tensor(0.5)
)
reconstructed_exact = (psi_diag * x).squeeze(0).numpy()

# 2. Coarse first-order approximation (the level often used as a "linear approximation")
approx_linear = 0.5 * x.squeeze(0).numpy()
# 3. Tanh-based cubic approximation (the widely cited GELU approximation formula)
def gelu_cubic_approx(z):
    """Approximate GELU via the tanh-based cubic formula (close to a low-order Hermite expansion)."""
    inner = np.sqrt(2 / np.pi) * (z + 0.044715 * z ** 3)
    return 0.5 * z * (1 + torch.tanh(inner))
approx_cubic = gelu_cubic_approx(x).squeeze(0).numpy()

# ────────────────────────────────────────────────
# Error computation (reference display)
# ────────────────────────────────────────────────
err_exact = np.abs(true_gelu - reconstructed_exact).max()
err_linear = np.abs(true_gelu - approx_linear).max()
err_cubic = np.abs(true_gelu - approx_cubic).max()

print("Max absolute error:")
print(f" Ψ (exact) : {err_exact:.2e}")
print(f" Linear : {err_linear:.2e}")
print(f" Cubic approx: {err_cubic:.2e}")
# ────────────────────────────────────────────────
# Plot
# ────────────────────────────────────────────────
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8.5), sharex=True)

# One spec per approximation:
# (series, linestyle, color, linewidth, value-label, error-label)
series_specs = [
    (reconstructed_exact, '--', 'limegreen', 2.2,
     'TensorLens Ψ (exact)', 'TensorLens Ψ error'),
    (approx_linear, '-.', 'orange', 1.8,
     'Linear approx (0.5x)', 'Linear error'),
    (approx_cubic, ':', 'purple', 2.1,
     'Cubic-like approx', 'Cubic approx error'),
]

# Top panel: function values compared against the true GELU
ax1.plot(x_np, true_gelu, 'k-', lw=2.5, label='True GELU')
for values, style, color, width, value_label, _ in series_specs:
    ax1.plot(x_np, values, style, color=color, lw=width, label=value_label)
ax1.set_title('GELU and its approximations (x ∈ [-4, 4])', fontsize=13)
ax1.set_ylabel('Value', fontsize=11)
ax1.legend(fontsize=10, loc='upper left')
ax1.grid(True, alpha=0.3)
ax1.set_ylim(-0.1, 4.1)

# Bottom panel: absolute error on a log scale
# (a tiny epsilon keeps exact zeros representable on the log axis)
for values, style, color, width, _, error_label in series_specs:
    ax2.plot(x_np, np.abs(true_gelu - values) + 1e-10,
             style, color=color, lw=width, label=error_label)
ax2.set_yscale('log')
ax2.set_xlabel('x', fontsize=11)
ax2.set_ylabel('Absolute error (log scale)', fontsize=11)
ax2.set_title('Approximation errors — log scale', fontsize=13)
ax2.legend(fontsize=10, loc='upper right')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
Author
pgtwitter
commented
Feb 1, 2026
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment