Created
March 2, 2018 15:42
-
-
Save izmailovpavel/65afa6212f21fe752e48056d8d723f9d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "import torch\n", | |
| "from torch import nn as nn\n", | |
| "from torch.autograd import Variable" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 63, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<torch.utils.hooks.RemovableHandle at 0x7f23101cdc18>" | |
| ] | |
| }, | |
| "execution_count": 63, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "grads = {}\n", | |
| "def save_grad(name):\n", | |
| " def hook(grad):\n", | |
| " grads[name] = grad\n", | |
| " return hook\n", | |
| "\n", | |
| "def extract_grad(var):\n", | |
| " print(var)\n", | |
| " print(var.shape)\n", | |
| " return var\n", | |
| "\n", | |
| "n_feat = 10\n", | |
| "n_obj = 25\n", | |
| "X = np.random.normal(size=(n_obj, n_feat))\n", | |
| "y = np.random.randint(low=0, high=n_feat, size=(n_obj))\n", | |
| "X_ = Variable(torch.from_numpy(X), requires_grad=True)\n", | |
| "y_ = Variable(torch.from_numpy(y))\n", | |
| "lsm = nn.LogSoftmax(dim=1)(X_)\n", | |
| "l = nn.NLLLoss()(lsm, y_)\n", | |
| "\n", | |
| "lsm.register_hook(save_grad('lsm'))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 64, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "l.backward()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 65, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Variable containing:\n", | |
| "1.00000e-02 *\n", | |
| " 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000\n", | |
| " 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000\n", | |
| " 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000\n", | |
| " 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000\n", | |
| " 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 0.0000 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " 0.0000 0.0000 0.0000 -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| " -4.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000\n", | |
| "[torch.DoubleTensor of size 25x10]" | |
| ] | |
| }, | |
| "execution_count": 65, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "grads['lsm']" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 66, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Variable containing:\n", | |
| "1.00000e-02 *\n", | |
| " 0.8155 0.2794 -3.8968 0.5208 0.1211 0.1943 0.5007 0.3851 0.1434 0.9365\n", | |
| " 0.5790 0.1383 0.5933 0.6330 0.4394 0.6952 0.1665 0.0873 -3.5001 0.1682\n", | |
| " 0.5062 0.0578 0.0737 1.0623 1.3683 0.1590 0.1494 0.5018 0.0833 -3.9618\n", | |
| " 0.1304 1.4510 -3.7681 0.0102 0.1075 0.6720 0.5830 0.4104 0.2180 0.1856\n", | |
| " 0.1847 0.4890 -3.7672 0.2188 0.9282 0.1364 0.0864 0.4260 0.4284 0.8693\n", | |
| " 0.0223 -3.5900 1.6741 0.1133 0.0790 0.3652 0.1131 0.0791 0.7763 0.3677\n", | |
| " 0.1598 0.1576 0.2000 0.1278 -3.8787 0.7640 0.4842 0.5352 1.2654 0.1847\n", | |
| " 0.6389 0.1485 0.1226 -3.6593 0.1309 0.1685 0.3405 1.6761 0.1646 0.2687\n", | |
| " 0.6277 -3.3878 0.2061 0.1202 1.0843 0.2062 0.4640 0.0806 0.4340 0.1647\n", | |
| " 0.6289 0.2413 0.0680 0.2713 0.2365 0.8447 0.0867 -2.9684 0.2418 0.3492\n", | |
| " 0.6839 0.0435 0.2797 0.1037 -3.7896 0.5422 0.6385 0.8123 0.0335 0.6521\n", | |
| " 0.2857 0.2350 0.7899 0.2513 0.8590 0.0355 -3.4502 0.2588 0.2522 0.4828\n", | |
| " 0.2759 0.1700 0.1678 0.2723 0.0284 0.0710 1.5839 0.7363 -3.4987 0.1932\n", | |
| " 0.3499 0.9402 0.4786 0.2367 0.7633 -3.6291 0.3840 0.0831 0.2335 0.1597\n", | |
| " 0.2092 0.1013 0.7504 0.2398 0.0702 -3.7920 0.2251 0.3815 1.5150 0.2994\n", | |
| " 0.1143 0.0623 0.2665 0.3133 0.5581 -3.8609 0.9103 0.4654 0.4638 0.7068\n", | |
| " 0.1304 0.2630 0.3659 1.7599 0.1529 -3.8142 0.3330 0.4531 0.2820 0.0741\n", | |
| " 0.0638 0.5095 -3.7011 0.2375 0.0971 0.0885 0.1205 0.0939 2.4365 0.0538\n", | |
| " 0.6268 -3.7864 0.1416 1.0899 0.7207 0.3281 0.0347 0.2979 0.2298 0.3171\n", | |
| " 0.4224 0.7358 -3.3912 0.2569 0.2338 0.3163 0.3436 0.4604 0.3109 0.3111\n", | |
| " 0.5334 0.0474 0.3382 0.8208 0.4691 -3.7323 0.4889 0.1613 0.4568 0.4165\n", | |
| " 0.4489 -3.5636 0.4416 0.0672 0.0471 0.8899 0.7920 0.6428 0.1531 0.0810\n", | |
| " 0.1111 0.1431 0.1120 -2.9565 0.6451 0.9736 0.0359 0.5117 0.2584 0.1655\n", | |
| " 0.3391 0.9316 0.2202 -3.4725 0.0128 0.1691 0.2743 0.4158 0.8600 0.2497\n", | |
| " -3.8255 0.6958 0.3361 0.3637 0.0943 0.3866 0.4016 0.6418 0.2407 0.6649\n", | |
| "[torch.DoubleTensor of size 25x10]" | |
| ] | |
| }, | |
| "execution_count": 66, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_.grad" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 56, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
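| "lsm.grad  # displays nothing: lsm is a non-leaf Variable, so its .grad is not retained; the hook-captured grads['lsm'] holds that gradient instead" | |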
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 72, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "dlsm_dXs = []\n", | |
| "for i in range(n_obj):\n", | |
| " denom = np.sum(np.exp(X[i]))\n", | |
| " dlsm_dXs.append(np.eye(n_feat) - np.exp(X[i][:, None]) / denom)  # entry [j, k] = d lsm[i, k] / d X[i, j] = delta_jk - softmax(X[i])[j]\n", | |
| "dlsm_dX = np.array(dlsm_dXs)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 74, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "dlsm_dX_ = Variable(torch.from_numpy(dlsm_dX))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 80, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "ans = np.einsum('ijk, ik -> ij', dlsm_dX, grads['lsm'].data.numpy())" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 84, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "2.5407775362056137e-17" | |
| ] | |
| }, | |
| "execution_count": 84, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "np.linalg.norm(ans - X_.grad.data.numpy())" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.5.2" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment