| """ | |
| This snippet is the toy experiment to numerically verify the math | |
| in the Implicit Dynamics of In-Context Learning paper. | |
| The full article is on medium title:Synaptic Plasticity in Transformer Part 1 | |
| The code has been converted from the original notebook to a python script. | |
| """ | |
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
#%%
# 1. SETUP
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model.eval()
#%%
# Select a specific layer to analyze (e.g., Layer 5)
LAYER_ID = 5
print(f"--- Experimenting on GPT-2 Layer {LAYER_ID} ---")
# 2. DEFINE INPUTS
# Context: A pattern the model should follow
context_text = "The capital of Malaysia is Kuala Lumpur. The capital of USA is Washington D.C. The capital of Germany is"
# Query: The token we want to predict next
query_text = " Berlin"
# Full prompt
full_text = context_text + query_text
# Tokenize
input_ids_full = tokenizer.encode(full_text, return_tensors='pt')
input_ids_query = tokenizer.encode(query_text, return_tensors='pt')
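# Sanity check (added for illustration, not part of the original gist): the last
# token of the full prompt should be the query token, since the experiment below
# compares the MLP input at that final position with and without the context.
print("Last token of full prompt:", repr(tokenizer.decode(input_ids_full[0, -1].item())))
print("Query token(s):", repr(tokenizer.decode(input_ids_query[0].tolist())))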
# 3. HOOKS TO CAPTURE ACTIVATIONS
# We need to capture the input *just before* the MLP's first linear layer (c_fc)
# In GPT-2: Block -> LN_2 -> MLP (c_fc -> act -> c_proj)
# The paper calls the input to the MLP "A(C, x)"
activations = {}
def get_activation(name):
    def hook(model, input, output):
        # input[0] is the tensor entering the layer
        activations[name] = input[0].detach()
    return hook
# Register hook on the first linear layer of the MLP in the target block
# by accessing model.transformer.h[LAYER_ID].mlp.c_fc
handle = model.transformer.h[LAYER_ID].mlp.c_fc.register_forward_hook(get_activation("mlp_input"))
#%%
input_ids_full.shape, input_ids_query.shape
#%%
# 4. PASS 1: WITH CONTEXT (The "Real" Result)
with torch.no_grad():
    _ = model(input_ids_full)
# The input to the MLP for the *last token* (the query token)
# Shape: [batch, seq_len, hidden_dim] -> we take the last token
A_context_x = activations["mlp_input"][0, -1, :].clone()
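# Quick illustration (added, not in the original gist): confirm the in-context
# pattern actually works, i.e. the model's top prediction right after the context
# is the query token " Berlin". The hook fires here too, but A_context_x has
# already been cloned above and Pass 2 below overwrites the dict anyway.
with torch.no_grad():
    ctx_ids = tokenizer.encode(context_text, return_tensors='pt')
    ctx_logits = model(ctx_ids).logits
top_id = ctx_logits[0, -1].argmax().item()
print("Top prediction after the context:", repr(tokenizer.decode(top_id)))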
# 5. PASS 2: WITHOUT CONTEXT (The "Base" Result)
# Note: in a real transformer, the position embeddings change when the query is run on its own.
# To verify the math strictly, the "no context" vector should sit at the same position
# it occupied in the full sequence, just without the attention history.
# For this demo, however, running the query alone is the closest semantic proxy.
with torch.no_grad():
    _ = model(input_ids_query)
A_x = activations["mlp_input"][0, -1, :].clone()
# Remove hook
handle.remove()
#%%
A_context_x.shape, A_x.shape
#%%
# 6. GET THE WEIGHTS
# GPT-2 uses Conv1D, so weights are transposed compared to a standard Linear layer;
# the shape is (hidden_size, intermediate_size)
W = model.transformer.h[LAYER_ID].mlp.c_fc.weight.detach()  # The 'W' in the paper
b = model.transformer.h[LAYER_ID].mlp.c_fc.bias.detach()    # The 'b' in the paper
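# Optional check (added for illustration, not in the original gist): GPT-2's Conv1D
# computes y = x @ W + b, which is the convention the algebra below relies on.
# Verify it by comparing a direct module call against the manual matmul.
with torch.no_grad():
    direct_out = model.transformer.h[LAYER_ID].mlp.c_fc(A_x.unsqueeze(0)).squeeze(0)
manual_out = torch.matmul(A_x, W) + b
print("Conv1D convention y = x @ W + b holds:",
      torch.allclose(direct_out, manual_out, atol=1e-4))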
# 7. PERFORM THE PAPER'S MATH
print("\n--- Calculating Implicit Dynamics ---")
# Calculate Delta A (the difference in activation caused by the context)
# ΔA = A(C, x) - A(x)
delta_A = A_context_x - A_x
# Paper formula for ΔW (Theorem 2.2 / Corollary 2.3.1):
# ΔW = ( (W * ΔA) * A(x)^T ) / ||A(x)||^2
# Note: the paper uses column vectors (M x 1), while PyTorch uses row vectors (1 x M),
# so we need to be careful with shapes.
# 1. Compute W * ΔA
# In PyTorch (row vectors): delta_A @ W
W_times_deltaA = torch.matmul(delta_A, W)  # Shape: [intermediate_size]
# 2. Compute the squared norm of A(x)
norm_Ax_sq = torch.norm(A_x)**2
# 3. Compute the rank-1 update matrix
# In the paper this is the outer product of (W * ΔA) and A(x): column @ row.
# Here the update must be added to W, which is (hidden_size, intermediate_size),
# so we effectively want the outer product of A(x) and (W * ΔA), divided by ||A(x)||^2.
# Look at the forward pass equation in PyTorch: y = x @ W + b
# We want: x @ (W + ΔW) = x @ W + x @ ΔW
# x @ ΔW must equal W_times_deltaA (which is the effect of the context processed by W)
# The scalar projection logic:
# We want ΔW such that: A_x @ ΔW = W_times_deltaA
# Using the paper's projection trick:
# ΔW = (A_x.T @ W_times_deltaA) / norm_Ax_sq
# A_x: [768]
# W_times_deltaA: [3072]
# Outer product: [768, 3072] -> matches W's shape
Delta_W = torch.outer(A_x, W_times_deltaA) / norm_Ax_sq
print(f"Weight Shape: {W.shape}")
print(f"Delta W Shape: {Delta_W.shape} (Rank 1 Matrix)")
#%%
# 8. VERIFICATION
# A) Real output (what happened in the full context pass)
# In PyTorch GPT-2: input @ W + b
real_output = torch.matmul(A_context_x, W) + b
# B) Theory output (input WITHOUT context, but weights WITH update)
# We use A_x (no-context input) and (W + Delta_W)
modified_W = W + Delta_W
theory_output = torch.matmul(A_x, modified_W) + b
# 9. COMPARE
diff = torch.norm(real_output - theory_output).item()
magnitude = torch.norm(real_output).item()
print("\n--- Results ---")
print(f"Magnitude of actual output: {magnitude:.4f}")
print(f"Difference (Euclidean Dist): {diff:.6f}")
print(f"Relative Error: {(diff/magnitude)*100:.6f}%")
if diff < 1e-4:
    print("\n✅ SUCCESS: The algebraic substitution holds.")
    print("The model output with context is IDENTICAL to the model output without context but with updated weights.")
else:
    print("\n❌ FAIL: Large discrepancy.")
#%%
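# Optional follow-up (a minimal sketch, not part of the original gist): repeat the
# same rank-1 substitution check for every GPT-2 block to show the identity is not
# specific to LAYER_ID = 5. The helper name `relative_error_for_layer` is illustrative.
def relative_error_for_layer(layer_id: int) -> float:
    acts = {}
    def hook(module, inputs, output):
        acts["mlp_input"] = inputs[0].detach()
    h = model.transformer.h[layer_id].mlp.c_fc.register_forward_hook(hook)
    with torch.no_grad():
        _ = model(input_ids_full)
        a_ctx = acts["mlp_input"][0, -1, :].clone()
        _ = model(input_ids_query)
        a_x = acts["mlp_input"][0, -1, :].clone()
    h.remove()
    w = model.transformer.h[layer_id].mlp.c_fc.weight.detach()
    bias = model.transformer.h[layer_id].mlp.c_fc.bias.detach()
    dw = torch.outer(a_x, torch.matmul(a_ctx - a_x, w)) / (torch.norm(a_x) ** 2)
    real = torch.matmul(a_ctx, w) + bias
    theory = torch.matmul(a_x, w + dw) + bias
    return (torch.norm(real - theory) / torch.norm(real)).item()

for lid in range(model.config.n_layer):
    print(f"Layer {lid:2d}: relative error = {relative_error_for_layer(lid):.2e}")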