Registering hooks in PyTorch
{
  "cells": [
    {
      "metadata": {},
      "cell_type": "markdown",
| "source": "demonstrating the hook registering.\n\nmostly copy-pasted from http://pytorch.org/tutorials/beginner/former_torchies/nn_tutorial.html" | |
    },
    {
      "metadata": {
        "trusted": true,
        "collapsed": true
      },
      "cell_type": "code",
      "source": "from torch import nn\nimport torch\nfrom torch.autograd import Variable",
      "execution_count": 1,
      "outputs": []
    },
    {
      "metadata": {
        "trusted": true
      },
      "cell_type": "code",
| "source": "def printnorm(self, input, output):\n # input is a tuple of packed inputs\n # output is a Variable. output.data is the Tensor we are interested\n print('Inside ' + self.__class__.__name__ + ' forward')\n print('-')\n \n if not isinstance(input, tuple):\n input = (input,)\n for input_i in input:\n print('type(input_i): ', type(input_i))\n print('input_i.size(): ', input_i.size())\n \n print('-')\n\n if not isinstance(output, tuple):\n output = (output,)\n for output_i in output:\n print('output_i size:', output_i.data.size())\n print('output_i norm:', output_i.data.norm())\n print('==\\n')\n\n \ndef printgradnorm(self, grad_input, grad_output):\n print('Inside ' + self.__class__.__name__ + ' backward')\n print('Inside class:' + self.__class__.__name__)\n print('-')\n \n if not isinstance(grad_input, tuple):\n grad_input = (grad_input,)\n \n for grad_input_i in grad_input:\n print('type(grad_input_i): ', type(grad_input_i))\n if grad_input_i is not None:\n print('grad_input_i.size(): ', grad_input_i[0].size())\n \n print('-')\n \n if not isinstance(grad_output, tuple):\n grad_output = (grad_output,)\n\n \n for grad_output_i in grad_output:\n print('type(grad_output_i): ', type(grad_output_i))\n if grad_output_i is not None:\n print('grad_output_i.size(): ', grad_output_i[0].size())\n \n print('==\\n')\n \nfc1 = nn.Linear(in_features=10, out_features=20)\nfc2 = nn.Linear(in_features=20, out_features=30)\n \nfc1.register_backward_hook(printgradnorm)\nfc1.register_forward_hook(printnorm)\nfc2.register_backward_hook(printgradnorm)\nfc2.register_forward_hook(printnorm)\n\nprint(fc1)\nprint(fc2)\n\nx = Variable(torch.randn(8,10))\no1 = fc1(x)\no2 = fc2(o1)\nerr = o2.mean()\nerr.backward()", | |
| "execution_count": 2, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": "Linear (10 -> 20)\nLinear (20 -> 30)\nInside Linear forward\n-\ntype(input_i): <class 'torch.autograd.variable.Variable'>\ninput_i.size(): torch.Size([8, 10])\n-\noutput_i size: torch.Size([8, 20])\noutput_i norm: 6.87514330817611\n==\n\nInside Linear forward\n-\ntype(input_i): <class 'torch.autograd.variable.Variable'>\ninput_i.size(): torch.Size([8, 20])\n-\noutput_i size: torch.Size([8, 30])\noutput_i norm: 5.230063070500036\n==\n\nInside Linear backward\nInside class:Linear\n-\ntype(grad_input_i): <class 'torch.autograd.variable.Variable'>\ngrad_input_i.size(): torch.Size([1])\ntype(grad_input_i): <class 'torch.autograd.variable.Variable'>\ngrad_input_i.size(): torch.Size([20])\ntype(grad_input_i): <class 'torch.autograd.variable.Variable'>\ngrad_input_i.size(): torch.Size([30])\n-\ntype(grad_output_i): <class 'torch.autograd.variable.Variable'>\ngrad_output_i.size(): torch.Size([30])\n==\n\nInside Linear backward\nInside class:Linear\n-\ntype(grad_input_i): <class 'torch.autograd.variable.Variable'>\ngrad_input_i.size(): torch.Size([1])\ntype(grad_input_i): <class 'NoneType'>\ntype(grad_input_i): <class 'torch.autograd.variable.Variable'>\ngrad_input_i.size(): torch.Size([20])\n-\ntype(grad_output_i): <class 'torch.autograd.variable.Variable'>\ngrad_output_i.size(): torch.Size([20])\n==\n\n", | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
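    {
      "metadata": {},
      "cell_type": "markdown",
      "source": "An aside that is not part of the original tutorial: each `register_forward_hook` / `register_backward_hook` call returns a handle, and calling `remove()` on that handle detaches the hook again. A minimal sketch, using a fresh, purely illustrative layer `fc3`:"
    },
    {
      "metadata": {
        "trusted": true
      },
      "cell_type": "code",
      "source": "# minimal sketch of hook removal (fc3 is a new, illustrative layer)\nfc3 = nn.Linear(in_features=10, out_features=5)\nhandle = fc3.register_forward_hook(printnorm)\n\n_ = fc3(Variable(torch.randn(8, 10)))  # printnorm fires on this forward pass\nhandle.remove()                        # detach the hook again\n_ = fc3(Variable(torch.randn(8, 10)))  # no hook output this time",
      "execution_count": null,
      "outputs": []
    },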
    {
      "metadata": {
        "trusted": true
      },
      "cell_type": "code",
| "source": "## the hook registry is not recursive.. e.g. in the case that you have a Module with submodules.., \nfrom collections import OrderedDict\nnet = nn.Sequential(OrderedDict({\n 'fc1': nn.Linear(in_features=10, out_features=20),\n 'fc2': nn.Linear(in_features=20, out_features=30)}))\n \nnet.register_forward_hook(printnorm)\nnet.register_backward_hook(printgradnorm)\n\n\nx = Variable(torch.randn(8,10))\ny = net(x)\nerr = y.mean()\nerr.backward()", | |
| "execution_count": 3, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": "Inside Sequential forward\n-\ntype(input_i): <class 'torch.autograd.variable.Variable'>\ninput_i.size(): torch.Size([8, 10])\n-\noutput_i size: torch.Size([8, 30])\noutput_i norm: 5.208447818216277\n==\n\nInside Sequential backward\nInside class:Sequential\n-\ntype(grad_input_i): <class 'torch.autograd.variable.Variable'>\ngrad_input_i.size(): torch.Size([1])\ntype(grad_input_i): <class 'torch.autograd.variable.Variable'>\ngrad_input_i.size(): torch.Size([20])\ntype(grad_input_i): <class 'torch.autograd.variable.Variable'>\ngrad_input_i.size(): torch.Size([30])\n-\ntype(grad_output_i): <class 'torch.autograd.variable.Variable'>\ngrad_output_i.size(): torch.Size([30])\n==\n\n", | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
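    {
      "metadata": {},
      "cell_type": "markdown",
      "source": "If per-submodule prints are wanted, one option (a sketch, not from the original gist) is to walk the container with `Module.modules()` and register the hooks on each submodule; `net2` and `handles` below are illustrative names."
    },
    {
      "metadata": {
        "trusted": true
      },
      "cell_type": "code",
      "source": "# sketch: register the hooks on every submodule of a container,\n# so each Linear reports separately during forward and backward\nnet2 = nn.Sequential(OrderedDict([\n    ('fc1', nn.Linear(in_features=10, out_features=20)),\n    ('fc2', nn.Linear(in_features=20, out_features=30))]))\n\nhandles = []\nfor m in net2.modules():\n    if m is not net2:  # modules() also yields the container itself; skip it\n        handles.append(m.register_forward_hook(printnorm))\n        handles.append(m.register_backward_hook(printgradnorm))\n\nx = Variable(torch.randn(8, 10))\nnet2(x).mean().backward()\n\nfor h in handles:  # clean up: detach all the hooks again\n    h.remove()",
      "execution_count": null,
      "outputs": []
    },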
    {
      "metadata": {
        "trusted": true,
        "collapsed": true
      },
      "cell_type": "code",
      "source": "",
      "execution_count": null,
      "outputs": []
    },
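    {
      "metadata": {},
      "cell_type": "markdown",
      "source": "Hooks are not limited to modules: a hook can also be registered on a `Variable` itself, in which case it receives the gradient w.r.t. that Variable during the backward pass. A minimal sketch (not from the original gist):"
    },
    {
      "metadata": {
        "trusted": true
      },
      "cell_type": "code",
      "source": "# sketch: a Variable-level hook; it is called with the gradient w.r.t. v\n# during backward (returning None leaves the gradient unchanged)\nv = Variable(torch.randn(8, 10), requires_grad=True)\nh = v.register_hook(lambda grad: print('grad norm on v:', grad.norm()))\n\nv.sum().backward()  # triggers the hook\nh.remove()          # detach the Variable hook again",
      "execution_count": null,
      "outputs": []
    }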
  ],
  "metadata": {
    "kernelspec": {
      "name": "t3",
      "display_name": "t3",
      "language": "python"
    },
    "language_info": {
      "name": "python",
      "file_extension": ".py",
      "nbconvert_exporter": "python",
      "codemirror_mode": {
        "version": 3,
        "name": "ipython"
      },
      "version": "3.6.2",
      "pygments_lexer": "ipython3",
      "mimetype": "text/x-python"
    },
    "gist": {
      "id": "53ae672f0de85da6db3e69533a00c0d3",
      "data": {
        "description": "registering hooks in pytorch",
        "public": true
      }
    },
    "_draft": {
      "nbviewer_url": "https://gist.github.com/53ae672f0de85da6db3e69533a00c0d3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}