Skip to content

Instantly share code, notes, and snippets.

@rahulvigneswaran
Created December 23, 2025 11:10
Show Gist options
  • Select an option

  • Save rahulvigneswaran/dce94d47596052fd7cc467c149523de4 to your computer and use it in GitHub Desktop.

Select an option

Save rahulvigneswaran/dce94d47596052fd7cc467c149523de4 to your computer and use it in GitHub Desktop.
MLA_RoPE-less.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyPXlRyDqhWFLBLi4mopmTLn",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/rahulvigneswaran/dce94d47596052fd7cc467c149523de4/mla_rope-less.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RInhMtFHx4pg",
"outputId": "700b6f87-17f2-4bc1-cb6e-a0444f74621b"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Input Shape: (1, 100)\n",
"Latent Query: (3, 1, 20)\n",
"Latent KV Cache: (6, 20)\n",
"Absorbed QK Kernel: (3, 20, 20)\n",
"Absorbed VO Kernel: (3, 20, 100)\n",
"Final Output Shape: (1, 100)\n"
]
}
],
"source": [
"# imports\n",
"import numpy as np\n",
"\n",
"# Seed the global RNG so Restart & Run All reproduces the same tensors.\n",
"np.random.seed(0)\n",
"\n",
"# dims\n",
"d_1 = 100  # model (input/output) width\n",
"d_2 = 20   # latent width used by both the query and the KV down-projection\n",
"d_3 = 60   # total up-projected width, split evenly across heads\n",
"h = 3      # number of attention heads\n",
"\n",
"existing_cache = 5  # latent KV rows already cached before the current token\n",
"\n",
"# inits\n",
"X = np.random.randn(1, d_1)\n",
"\n",
"W_dQ = np.random.randn(d_1, d_2)\n",
"W_dKV = np.random.randn(d_1, d_2)\n",
"\n",
"W_uQ = np.random.randn(d_2, d_3)\n",
"W_uK = np.random.randn(d_2, d_3)\n",
"W_uV = np.random.randn(d_2, d_3)\n",
"\n",
"W_O = np.random.randn(d_3, d_1)\n",
"\n",
"C_KV = np.random.randn(existing_cache, d_2)\n",
"\n",
"\n",
"def softmax(z, axis=-1):\n",
"    \"\"\"Return the softmax of z along axis, shifted by the max for numerical stability.\"\"\"\n",
"    shifted = z - z.max(axis=axis, keepdims=True)\n",
"    exp_z = np.exp(shifted)\n",
"    return exp_z / exp_z.sum(axis=axis, keepdims=True)\n",
"\n",
"\n",
"# head splitting: one (d_2, d_h) up-projection per head, heads on axis 0\n",
"assert d_3 % h == 0, \"d_3 must be divisible by h\"\n",
"d_h = d_3 // h\n",
"\n",
"W_uQ = W_uQ.reshape(d_2, h, d_h).transpose(1, 0, 2)\n",
"W_uK = W_uK.reshape(d_2, h, d_h).transpose(1, 0, 2)\n",
"W_uV = W_uV.reshape(d_2, h, d_h).transpose(1, 0, 2)\n",
"W_O = W_O.reshape(h, d_h, d_1)\n",
"\n",
"# matrix absorption: fold the up-projections into per-head kernels that act on latents\n",
"W_aQK = W_uQ @ W_uK.transpose(0, 2, 1)  # (h, d_2, d_2)\n",
"W_aVO = W_uV @ W_O  # (h, d_2, d_1)\n",
"\n",
"# update cache: append the current token's latent KV row\n",
"new_C_KV = X @ W_dKV\n",
"C_KV = np.vstack((C_KV, new_C_KV))\n",
"\n",
"# Attention part (left side): logits are taken directly against the latent cache,\n",
"# then scaled by 1/sqrt(d_h) and normalised over the cached positions.\n",
"C_Q = X @ W_dQ @ W_aQK\n",
"logits = C_Q @ C_KV.T / np.sqrt(d_h)\n",
"attn = softmax(logits, axis=-1)\n",
"\n",
"# Value part (right side): latent cache pushed through the absorbed V/O kernel\n",
"value_proj = C_KV @ W_aVO\n",
"\n",
"# combine\n",
"out_heads = attn @ value_proj\n",
"\n",
"# adding trick: summing the per-head outputs equals concat(heads) @ W_O\n",
"final_out = out_heads.sum(axis=0)\n",
"\n",
"# --- VERIFICATION ---\n",
"# Reference path without absorption: explicit per-head Q, K, V and output projection.\n",
"Q_ref = (X @ W_dQ) @ W_uQ\n",
"K_ref = C_KV @ W_uK\n",
"V_ref = C_KV @ W_uV\n",
"attn_ref = softmax(Q_ref @ K_ref.transpose(0, 2, 1) / np.sqrt(d_h), axis=-1)\n",
"ref_out = ((attn_ref @ V_ref) @ W_O).sum(axis=0)\n",
"np.testing.assert_allclose(final_out, ref_out, rtol=1e-6, atol=1e-6)\n",
"\n",
"print(f\"Input Shape: {X.shape}\")\n",
"print(f\"Latent Query: {C_Q.shape}\")\n",
"print(f\"Latent KV Cache: {C_KV.shape}\")\n",
"print(f\"Absorbed QK Kernel: {W_aQK.shape}\")\n",
"print(f\"Absorbed VO Kernel: {W_aVO.shape}\")\n",
"print(f\"Final Output Shape: {final_out.shape}\")\n",
"\n",
"# TODO(compression): left open in the original notes. One candidate measurement:\n",
"# compare the latent cache (d_2 floats per token) with caching full K and V (2 * d_3 floats per token)."
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment