registering hooks in pytorch
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "demonstrating the hook registering.\n\nmostly copy-pasted from http://pytorch.org/tutorials/beginner/former_torchies/nn_tutorial.html"
},
{
"metadata": {
"trusted": true,
"collapsed": true
},
"cell_type": "code",
"source": "from torch import nn\nimport torch\nfrom torch.autograd import Variable",
"execution_count": 1,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def printnorm(self, input, output):\n # input is a tuple of packed inputs\n # output is a Variable. output.data is the Tensor we are interested\n print('Inside ' + self.__class__.__name__ + ' forward')\n print('-')\n \n if not isinstance(input, tuple):\n input = (input,)\n for input_i in input:\n print('type(input_i): ', type(input_i))\n print('input_i.size(): ', input_i.size())\n \n print('-')\n\n if not isinstance(output, tuple):\n output = (output,)\n for output_i in output:\n print('output_i size:', output_i.data.size())\n print('output_i norm:', output_i.data.norm())\n print('==\\n')\n\n \ndef printgradnorm(self, grad_input, grad_output):\n print('Inside ' + self.__class__.__name__ + ' backward')\n print('Inside class:' + self.__class__.__name__)\n print('-')\n \n if not isinstance(grad_input, tuple):\n grad_input = (grad_input,)\n \n for grad_input_i in grad_input:\n print('type(grad_input_i): ', type(grad_input_i))\n if grad_input_i is not None:\n print('grad_input_i.size(): ', grad_input_i[0].size())\n \n print('-')\n \n if not isinstance(grad_output, tuple):\n grad_output = (grad_output,)\n\n \n for grad_output_i in grad_output:\n print('type(grad_output_i): ', type(grad_output_i))\n if grad_output_i is not None:\n print('grad_output_i.size(): ', grad_output_i[0].size())\n \n print('==\\n')\n \nfc1 = nn.Linear(in_features=10, out_features=20)\nfc2 = nn.Linear(in_features=20, out_features=30)\n \nfc1.register_backward_hook(printgradnorm)\nfc1.register_forward_hook(printnorm)\nfc2.register_backward_hook(printgradnorm)\nfc2.register_forward_hook(printnorm)\n\nprint(fc1)\nprint(fc2)\n\nx = Variable(torch.randn(8,10))\no1 = fc1(x)\no2 = fc2(o1)\nerr = o2.mean()\nerr.backward()",
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": "Linear (10 -> 20)\nLinear (20 -> 30)\nInside Linear forward\n-\ntype(input_i): <class 'torch.autograd.variable.Variable'>\ninput_i.size(): torch.Size([8, 10])\n-\noutput_i size: torch.Size([8, 20])\noutput_i norm: 6.87514330817611\n==\n\nInside Linear forward\n-\ntype(input_i): <class 'torch.autograd.variable.Variable'>\ninput_i.size(): torch.Size([8, 20])\n-\noutput_i size: torch.Size([8, 30])\noutput_i norm: 5.230063070500036\n==\n\nInside Linear backward\nInside class:Linear\n-\ntype(grad_input_i): <class 'torch.autograd.variable.Variable'>\ngrad_input_i.size(): torch.Size([1])\ntype(grad_input_i): <class 'torch.autograd.variable.Variable'>\ngrad_input_i.size(): torch.Size([20])\ntype(grad_input_i): <class 'torch.autograd.variable.Variable'>\ngrad_input_i.size(): torch.Size([30])\n-\ntype(grad_output_i): <class 'torch.autograd.variable.Variable'>\ngrad_output_i.size(): torch.Size([30])\n==\n\nInside Linear backward\nInside class:Linear\n-\ntype(grad_input_i): <class 'torch.autograd.variable.Variable'>\ngrad_input_i.size(): torch.Size([1])\ntype(grad_input_i): <class 'NoneType'>\ntype(grad_input_i): <class 'torch.autograd.variable.Variable'>\ngrad_input_i.size(): torch.Size([20])\n-\ntype(grad_output_i): <class 'torch.autograd.variable.Variable'>\ngrad_output_i.size(): torch.Size([20])\n==\n\n",
"name": "stdout"
}
]
},
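{
"metadata": {},
"cell_type": "markdown",
"source": "Side note, not from the original tutorial: each register_forward_hook / register_backward_hook call returns a handle (a torch.utils.hooks.RemovableHandle), and calling .remove() on it detaches the hook again. A minimal sketch; fc3 is a fresh layer made up for illustration:"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# the register_* calls return a handle; keeping it lets us detach the hook later\nfc3 = nn.Linear(in_features=10, out_features=20)\nhandle = fc3.register_forward_hook(printnorm)\n\nx = Variable(torch.randn(8, 10))\ny = fc3(x)       # printnorm fires here\n\nhandle.remove()  # detach the hook\ny = fc3(x)       # no printout this time",
"execution_count": null,
"outputs": []
},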
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "## the hook registry is not recursive.. e.g. in the case that you have a Module with submodules.., \nfrom collections import OrderedDict\nnet = nn.Sequential(OrderedDict({\n 'fc1': nn.Linear(in_features=10, out_features=20),\n 'fc2': nn.Linear(in_features=20, out_features=30)}))\n \nnet.register_forward_hook(printnorm)\nnet.register_backward_hook(printgradnorm)\n\n\nx = Variable(torch.randn(8,10))\ny = net(x)\nerr = y.mean()\nerr.backward()",
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": "Inside Sequential forward\n-\ntype(input_i): <class 'torch.autograd.variable.Variable'>\ninput_i.size(): torch.Size([8, 10])\n-\noutput_i size: torch.Size([8, 30])\noutput_i norm: 5.208447818216277\n==\n\nInside Sequential backward\nInside class:Sequential\n-\ntype(grad_input_i): <class 'torch.autograd.variable.Variable'>\ngrad_input_i.size(): torch.Size([1])\ntype(grad_input_i): <class 'torch.autograd.variable.Variable'>\ngrad_input_i.size(): torch.Size([20])\ntype(grad_input_i): <class 'torch.autograd.variable.Variable'>\ngrad_input_i.size(): torch.Size([30])\n-\ntype(grad_output_i): <class 'torch.autograd.variable.Variable'>\ngrad_output_i.size(): torch.Size([30])\n==\n\n",
"name": "stdout"
}
]
},
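{
"metadata": {},
"cell_type": "markdown",
"source": "A sketch of a workaround, not from the original tutorial: since a hook only fires for the Module it is registered on, walk net.modules() and register the hook on each submodule individually. net2 and handles below are names made up for illustration:"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# register the forward hook on every module individually;\n# net2.modules() yields the Sequential itself plus each child\nnet2 = nn.Sequential(OrderedDict([\n    ('fc1', nn.Linear(in_features=10, out_features=20)),\n    ('fc2', nn.Linear(in_features=20, out_features=30))]))\n\nhandles = [m.register_forward_hook(printnorm) for m in net2.modules()]\n\nx = Variable(torch.randn(8, 10))\ny = net2(x)  # prints once for the Sequential and once per Linear\n\nfor h in handles:\n    h.remove()  # detach all the hooks again",
"execution_count": null,
"outputs": []
},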
{
"metadata": {
"trusted": true,
"collapsed": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "t3",
"display_name": "t3",
"language": "python"
},
"language_info": {
"name": "python",
"file_extension": ".py",
"nbconvert_exporter": "python",
"codemirror_mode": {
"version": 3,
"name": "ipython"
},
"version": "3.6.2",
"pygments_lexer": "ipython3",
"mimetype": "text/x-python"
},
"gist": {
"id": "53ae672f0de85da6db3e69533a00c0d3",
"data": {
"description": "registering hooks in pytorch",
"public": true
}
},
"_draft": {
"nbviewer_url": "https://gist.github.com/53ae672f0de85da6db3e69533a00c0d3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}