"""
This snippet is the toy experiment to numerically verify the math
in the Implicit Dynamics of In-Context Learning paper.
The full article is on medium title:Synaptic Plasticity in Transformer Part 1
The code has been converted from the original notebook to a python script.
"""
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
#%%
# 1. SETUP
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model.eval()
#%%
# Select a specific layer to analyze (e.g., Layer 5)
LAYER_ID = 5
print(f"--- Experimenting on GPT-2 Layer {LAYER_ID} ---")
# 2. DEFINE INPUTS
# Context: A pattern the model should follow
context_text = "The capital of Malaysia is Kuala Lumpur. The capital of USA is Washington D.C. The capital of Germany is"
# Query: The token we want to predict next
query_text = " Berlin"
# Full prompt
full_text = context_text + query_text
# Tokenize
input_ids_full = tokenizer.encode(full_text, return_tensors='pt')
input_ids_query = tokenizer.encode(query_text, return_tensors='pt')
# 3. HOOKS TO CAPTURE ACTIVATIONS
# We need to capture the input *just before* the MLP's first linear layer (c_fc)
# In GPT-2: Block -> LN_2 -> MLP (c_fc -> act -> c_proj)
# The paper calls the input to the MLP "A(C, x)"
activations = {}
def get_activation(name):
    def hook(model, input, output):
        # input[0] is the tensor entering the layer
        activations[name] = input[0].detach()
    return hook
# Register hook on the first linear layer of the MLP in the target block
# accessing model.transformer.h[LAYER_ID].mlp.c_fc
handle = model.transformer.h[LAYER_ID].mlp.c_fc.register_forward_hook(get_activation("mlp_input"))
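# Optional sanity check: print the target block to confirm where c_fc sits
# inside the MLP (ln_2 -> mlp.c_fc -> activation -> mlp.c_proj).
print(model.transformer.h[LAYER_ID])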
#%%
print(input_ids_full.shape, input_ids_query.shape)
#%%
# 4. PASS 1: WITH CONTEXT (The "Real" Result)
with torch.no_grad():
    _ = model(input_ids_full)
# The input to the MLP for the *last token* (the query token)
# Shape: [batch, seq_len, hidden_dim] -> we take last token
A_context_x = activations["mlp_input"][0, -1, :].clone()
# 5. PASS 2: WITHOUT CONTEXT (The "Base" Result)
# Note: In a real transformer, the position embeddings change if we run the query alone.
# To strictly verify the math, the "no context" version of the vector would need to sit
# at the same position it occupied in the full sequence, just without the attention history.
# For this demo, however, running the query alone is the closest "semantic" proxy
# (see the optional position_ids variant below).
with torch.no_grad():
    _ = model(input_ids_query)
A_x = activations["mlp_input"][0, -1, :].clone()
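# Optional variant (an illustrative assumption, not part of the original experiment):
# the position-embedding caveat above can be controlled for by placing the lone query
# token at the same absolute position it occupied in the full sequence, via the
# position_ids argument of the HF GPT-2 forward pass. The rest of the script keeps the
# simpler "query alone" proxy stored in A_x.
query_start = input_ids_full.shape[1] - input_ids_query.shape[1]
with torch.no_grad():
    _ = model(input_ids_query,
              position_ids=torch.arange(query_start, input_ids_full.shape[1]).unsqueeze(0))
A_x_same_pos = activations["mlp_input"][0, -1, :].clone()  # alternative baseline, unused below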
# Remove hook
handle.remove()
#%%
print(A_context_x.shape, A_x.shape)
#%%
# 6. GET THE WEIGHTS
# GPT-2 uses Conv1D, so weights are transposed compared to standard Linear
# shape is (hidden_size, intermediate_size)
W = model.transformer.h[LAYER_ID].mlp.c_fc.weight.detach() # The 'W' in the paper
b = model.transformer.h[LAYER_ID].mlp.c_fc.bias.detach() # The 'b' in the paper
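# Sanity check (a minimal sketch, assuming the default GPT-2 small config):
# Conv1D stores the weight as (in_features, out_features), so c_fc.weight should be
# (n_embd, 4 * n_embd) = (768, 3072), i.e. already oriented for x @ W + b.
assert W.shape == (model.config.n_embd, 4 * model.config.n_embd), W.shape
assert b.shape == (4 * model.config.n_embd,), b.shape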
# 7. PERFORM THE PAPER'S MATH
print("\n--- Calculating Implicit Dynamics ---")
# Calculate Delta A (Difference in activation caused by context)
# ΔA = A(C, x) - A(x)
delta_A = A_context_x - A_x
# Paper Formula for ΔW (Theorem 2.2 / Corollary 2.3.1):
# ΔW = ( (W * ΔA) * A(x)^T ) / ||A(x)||^2
# Note: The paper uses column vectors (Mx1). PyTorch uses row vectors (1xM).
# We need to be careful with shapes.
# 1. Compute W * ΔA
# In PyTorch (Row vectors): delta_A @ W
W_times_deltaA = torch.matmul(delta_A, W) # Shape: [intermediate_size]
# 2. Compute Norm squared of A(x)
norm_Ax_sq = torch.norm(A_x)**2
# 3. Compute the Rank-1 Update Matrix
# In the paper this is the outer product of (W ΔA) with A(x) (column @ row).
# In PyTorch the forward pass is y = x @ W + b, with W of shape (hidden, intermediate),
# so we need a ΔW of the same shape such that:
#   A_x @ (W + ΔW) = A_x @ W + A_x @ ΔW  equals  A_context_x @ W
# i.e. A_x @ ΔW must equal ΔA @ W (= W_times_deltaA, the effect of the context as seen by W).
# The paper's projection trick gives exactly that:
#   ΔW = outer(A_x, W_times_deltaA) / ||A_x||^2
# Shapes: A_x is [768], W_times_deltaA is [3072], so the outer product is [768, 3072],
# matching W.
Delta_W = torch.outer(A_x, W_times_deltaA) / norm_Ax_sq
print(f"Weight Shape: {W.shape}")
print(f"Delta W Shape: {Delta_W.shape} (Rank 1 Matrix)")
#%%
# 8. VERIFICATION
# A) Real Output (What happened in the full context pass)
# In PyTorch GPT-2: input @ W + b
real_output = torch.matmul(A_context_x, W) + b
# B) Theory Output (Input WITHOUT context, but weights WITH update)
# We use A_x (no context input) and (W + Delta_W)
modified_W = W + Delta_W
theory_output = torch.matmul(A_x, modified_W) + b
# 9. COMPARE
diff = torch.norm(real_output - theory_output).item()
magnitude = torch.norm(real_output).item()
print("\n--- Results ---")
print(f"Magnitude of actual output: {magnitude:.4f}")
print(f"Difference (Euclidean Dist): {diff:.6f}")
print(f"Relative Error: {(diff/magnitude)*100:.6f}%")
if diff < 1e-4:
    print("\n✅ SUCCESS: The algebraic substitution holds.")
    print("The model output with context is IDENTICAL to the model output without context but with updated weights.")
else:
    print("\n❌ FAIL: Large discrepancy.")
#%%