@pgtwitter
Created February 1, 2026 01:43
Comparison of the SiLU function: TensorLens-style reconstruction, a coarse first-order approximation, and a sigmoid-based approximation
# %%
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
# ────────────────────────────────────────────────
# Data preparation
# ────────────────────────────────────────────────
x_np = np.linspace(-4, 4, 401)  # fine grid over [-4, 4]
x = torch.from_numpy(x_np).float().unsqueeze(0) # shape: (1, 401)
# Reference SiLU
true_silu = F.silu(x).squeeze(0).numpy()
# 1. TensorLens-style: exact reconstruction with an input-dependent diagonal tensor Ψ
mask = x.abs() > 1e-6
psi_diag = torch.where(
    mask,
    F.silu(x) / x,      # Ψ entries where x is safely away from 0
    torch.tensor(0.5)   # fallback near 0: lim SiLU(x)/x = sigmoid(0) = 0.5
)
reconstructed_exact = (psi_diag * x).squeeze(0).numpy()
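# Optional sanity check (a minimal sketch, not required for the plots below):
# applying Ψ as an explicit diagonal operator reproduces the elementwise product.
psi_mat = torch.diag(psi_diag.squeeze(0))                  # (401, 401) diagonal matrix
reconstructed_via_matrix = (psi_mat @ x.squeeze(0)).numpy()
assert np.allclose(reconstructed_via_matrix, reconstructed_exact)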
# 2. Coarse first-order approximation (slope ≈ 0.5 near the origin)
approx_linear = 0.5 * x.squeeze(0).numpy()
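# (Why 0.5: SiLU'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))), so the slope at
#  the origin is SiLU'(0) = sigmoid(0) = 0.5.)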
# 3. Sigmoid-based form (the mathematical definition of SiLU itself, but small floating-point differences can appear)
approx_sigmoid = (x * torch.sigmoid(x)).squeeze(0).numpy()
# ────────────────────────────────────────────────
# Error check (reference output)
# ────────────────────────────────────────────────
print("Max absolute error:")
print(f" Ψ (exact) : {np.abs(true_silu - reconstructed_exact).max():.2e}")
print(f" Linear (0.5x) : {np.abs(true_silu - approx_linear).max():.2e}")
print(f" Sigmoid definition : {np.abs(true_silu - approx_sigmoid).max():.2e}")
# ────────────────────────────────────────────────
# Plot
# ────────────────────────────────────────────────
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8.5), sharex=True)
# Top panel: function values
ax1.plot(x_np, true_silu, 'k-', lw=2.5, label='True SiLU')
ax1.plot(x_np, reconstructed_exact, '--', color='limegreen', lw=2.2,
         label='TensorLens Ψ (exact)')
ax1.plot(x_np, approx_linear, '-.', color='orange', lw=1.8,
         label='Linear approx (0.5x)')
ax1.plot(x_np, approx_sigmoid, ':', color='purple', lw=2.1,
         label='x · sigmoid(x) approx')
ax1.set_title('SiLU and its approximations (x ∈ [-4, 4])', fontsize=13)
ax1.set_ylabel('Value', fontsize=11)
ax1.legend(fontsize=10, loc='upper left')
ax1.grid(True, alpha=0.3)
ax1.set_ylim(-0.1, 4.1)
# Bottom panel: absolute error (log scale)
ax2.plot(x_np, np.abs(true_silu - reconstructed_exact) + 1e-10,
         '--', color='limegreen', lw=2.2, label='Ψ error')
ax2.plot(x_np, np.abs(true_silu - approx_linear) + 1e-10,
         '-.', color='orange', lw=1.8, label='Linear error')
ax2.plot(x_np, np.abs(true_silu - approx_sigmoid) + 1e-10,
         ':', color='purple', lw=2.1, label='Sigmoid approx error')
ax2.set_yscale('log')
ax2.set_xlabel('x', fontsize=11)
ax2.set_ylabel('Absolute error (log scale)', fontsize=11)
ax2.set_title('Approximation errors — log scale', fontsize=13)
ax2.legend(fontsize=10, loc='upper right')
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
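# To reproduce the figure: save this script (e.g. as silu_compare.py; the name
# is arbitrary) and run `python silu_compare.py` in an environment with torch,
# numpy, and matplotlib installed.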