Weight-initialization-methods
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from comet_ml import Experiment\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as dsets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "COMET INFO: Experiment is live on comet.ml https://www.comet.ml/ceceshao1/weight-initialization/b8fec6d781a04706b94d85499e09ec95\n",
      "\n"
     ]
    }
   ],
   "source": [
    "experiment = Experiment(api_key=\"YOUR_API_KEY\",\n",
    "                        project_name=\"YOUR_PROJECT_NAME\", workspace=\"YOUR_WORKSPACE_NAME\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
| "# Set seed\n", | |
| "random_seed = torch.manual_seed(19)\n", | |
| "experiment.log_other(random_seed, 19)" | |
| ] | |
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scheduler import\n",
    "from torch.optim.lr_scheduler import StepLR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"800px\"\n",
       "            src=\"https://www.comet.ml/ceceshao1/weight-initialization/b8fec6d781a04706b94d85499e09ec95\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x104e87f60>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "experiment.display()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = dsets.MNIST(root='./data',\n",
    "                            train=True,\n",
    "                            transform=transforms.ToTensor(),\n",
    "                            download=True)\n",
    "\n",
    "test_dataset = dsets.MNIST(root='./data',\n",
    "                           train=False,\n",
    "                           transform=transforms.ToTensor())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment.log_dataset_hash(train_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Make Dataset Iterable (using data loaders)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set parameters\n",
    "\n",
    "batch_size = 100\n",
    "n_iters = 3000\n",
    "num_epochs = n_iters / (len(train_dataset) / batch_size)\n",
    "num_epochs = int(num_epochs)\n",
    "learning_rate = 0.1\n",
    "\n",
    "params = {\n",
    "    \"batch_size\": batch_size,\n",
    "    \"n_iters\": n_iters,\n",
    "    \"num_epochs\": num_epochs,\n",
    "    \"learning_rate\": learning_rate\n",
    "}\n",
    "\n",
    "experiment.log_parameters(params)"
   ]
  },
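  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the 60,000-image MNIST training set and a batch size of 100, each epoch is 600 iterations, so `num_epochs = 3000 / 600 = 5`."
   ]
  },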
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(dataset=train_dataset,\n",
    "                                           batch_size=batch_size,\n",
    "                                           shuffle=True)\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(dataset=test_dataset,\n",
    "                                          batch_size=batch_size,\n",
    "                                          shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create the Model Class\n",
    "\n",
    "> Note: choose the activation (tanh or ReLU) and the weight initialization when instantiating the model; the pairings that make sense for each activation are listed below.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### OPTION 1:\n",
    "Using the tanh activation (`nn.Tanh()`) with:\n",
    "- normal weight initialization\n",
    "- LeCun weight initialization\n",
    "- Xavier weight initialization\n",
    "\n",
    "#### OPTION 2:\n",
    "Using the ReLU activation (`nn.ReLU()`) with:\n",
    "- Xavier weight initialization\n",
    "- He weight initialization"
   ]
  },
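  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough reference, the cell below applies the three explicit schemes used by `initialize_weights` (defined in the next cell) to a tensor shaped like `fc1.weight` and prints the standard deviation each produces. `lecun` simply means keeping PyTorch's default `nn.Linear` initialization, which draws from a uniform distribution bounded by `1/sqrt(fan_in)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the spread each initialization scheme produces on a 100 x 784 weight matrix\n",
    "w = torch.empty(100, 28*28)\n",
    "\n",
    "nn.init.normal_(w, mean=0, std=1)\n",
    "print('normal std: %.3f' % w.std().item())   # ~1.0\n",
    "\n",
    "nn.init.xavier_normal_(w)\n",
    "print('xavier std: %.3f' % w.std().item())   # ~sqrt(2 / (fan_in + fan_out))\n",
    "\n",
    "nn.init.kaiming_normal_(w)\n",
    "print('he std: %.3f' % w.std().item())       # ~sqrt(2 / fan_in)"
   ]
  },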
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "def initialize_weights(weights, init_type):\n",
    "    if init_type == \"normal\":\n",
    "        nn.init.normal_(weights, mean=0, std=1)\n",
    "    elif init_type == \"xavier\":\n",
    "        nn.init.xavier_normal_(weights)\n",
    "    elif init_type == \"he\":\n",
    "        nn.init.kaiming_normal_(weights)\n",
    "    # default is lecun: keep PyTorch's default nn.Linear initialization\n",
    "\n",
    "class FeedforwardNeuralNetModel(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim, init_type, activation):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
    "        if activation == 'relu':\n",
    "            self.activation = nn.ReLU()\n",
    "        else:\n",
    "            self.activation = nn.Tanh()\n",
    "        self.fc2 = nn.Linear(hidden_dim, output_dim)\n",
    "\n",
    "        initialize_weights(self.fc1.weight, init_type)\n",
    "        initialize_weights(self.fc2.weight, init_type)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.fc1(x)\n",
    "        out = self.activation(out)\n",
    "        return self.fc2(out)"
   ]
  },
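  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check on the class wiring (using a throwaway instance, not the model trained below): a flattened 28x28 batch should come out as 10 logits per image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Throwaway instance just to confirm the input/output shapes\n",
    "demo_model = FeedforwardNeuralNetModel(28*28, 100, 10, 'xavier', 'relu')\n",
    "print(demo_model(torch.zeros(2, 28*28)).shape)   # torch.Size([2, 10])"
   ]
  },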
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Instantiate Model Class\n",
    "\n",
    "Change the following to test different configurations:\n",
    "\n",
    "- `init_type` can be `normal`, `xavier`, `he`, or `lecun`\n",
    "- `activation` can be `relu` or `tanh`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_dim = 28*28\n",
    "hidden_dim = 100\n",
    "output_dim = 10\n",
    "\n",
    "init_type = 'lecun'\n",
    "activation = 'tanh'\n",
    "\n",
    "model = FeedforwardNeuralNetModel(input_dim, hidden_dim, output_dim, init_type, activation)\n",
    "experiment.log_other('activation', activation)\n",
    "experiment.log_other(\"initialization\", init_type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Loss Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
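  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`nn.CrossEntropyLoss` applies log-softmax internally, so the model's raw logits are passed to it directly. A small dummy batch illustrates the call:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dummy batch: 4 rows of raw logits and their integer class labels\n",
    "dummy_logits = torch.randn(4, output_dim)\n",
    "dummy_labels = torch.tensor([0, 3, 1, 9])\n",
    "print(criterion(dummy_logits, dummy_labels))"
   ]
  },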
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Optimizer Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Instantiate Step Learning Scheduler Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "# step_size: number of epochs between learning-rate decays\n",
    "# step_size = 1: after every 1 epoch, new_lr = lr * gamma\n",
    "# step_size = 2: after every 2 epochs, new_lr = lr * gamma\n",
    "# gamma: decay factor\n",
    "\n",
    "scheduler = StepLR(optimizer, step_size=1, gamma=0.96)\n",
    "experiment.log_parameter(\"gamma\", 0.96)"
   ]
  },
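  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `step_size=1`, the learning rate after epoch *k* is `learning_rate * gamma ** k`; the next cell just prints the first few values of that schedule."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First few learning-rate values under StepLR(step_size=1, gamma=0.96)\n",
    "for k in range(4):\n",
    "    print('epoch', k, 'lr =', learning_rate * 0.96 ** k)"
   ]
  },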
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0 LR: [0.08847359999999999]\n",
      "Iteration: 1200. Loss: 0.03941712900996208. Accuracy: 97.08\n"
     ]
    }
   ],
   "source": [
    "batches = 0\n",
    "for epoch in range(num_epochs):\n",
    "    # Print the current learning rate\n",
    "    print('Epoch:', epoch, 'LR:', scheduler.get_lr())\n",
    "    for i, (images, labels) in enumerate(train_loader):\n",
    "        # Flatten images into (batch_size, 784) tensors\n",
    "        images = images.view(-1, 28*28)\n",
    "\n",
    "        # Clear gradients w.r.t. parameters\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Forward pass to get output/logits\n",
    "        outputs = model(images)\n",
    "\n",
    "        # Calculate loss: softmax --> cross entropy loss\n",
    "        loss = criterion(outputs, labels)\n",
    "        experiment.log_metric(\"loss\", loss)\n",
    "\n",
    "        # Get gradients w.r.t. parameters\n",
    "        loss.backward()\n",
    "\n",
    "        # Update parameters\n",
    "        optimizer.step()\n",
    "\n",
    "        batches += 1\n",
    "\n",
    "        if batches % 500 == 0:\n",
    "            # Calculate accuracy on the test set\n",
    "            correct = 0\n",
    "            total = 0\n",
    "            # Iterate through the test dataset\n",
    "            for images, labels in test_loader:\n",
    "                # Flatten images\n",
    "                images = images.view(-1, 28*28)\n",
    "\n",
    "                # Forward pass only to get logits/output\n",
    "                outputs = model(images)\n",
    "\n",
    "                # Get predictions from the maximum value\n",
    "                _, predicted = torch.max(outputs, 1)\n",
    "\n",
    "                # Total number of labels\n",
    "                total += labels.size(0)\n",
    "\n",
    "                # Total correct predictions\n",
    "                correct += (predicted.type(torch.FloatTensor).cpu() == labels.type(torch.FloatTensor)).sum()\n",
    "\n",
    "            accuracy = 100. * correct.item() / total\n",
    "            experiment.log_metric(\"accuracy\", accuracy)\n",
    "\n",
    "            # Print loss and accuracy\n",
    "            print('Iteration: {}. Loss: {}. Accuracy: {}'.format(batches, loss.item(), accuracy))\n",
    "    # Decay the learning rate once per epoch\n",
    "    scheduler.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "COMET INFO: ----------------------------\n",
      "COMET INFO: Comet.ml Experiment Summary:\n",
      "COMET INFO: Data:\n",
      "COMET INFO:   url: https://www.comet.ml/ceceshao1/weight-initialization/b8fec6d781a04706b94d85499e09ec95\n",
      "COMET INFO: Metrics:\n",
      "COMET INFO:   accuracy: 97.36\n",
      "COMET INFO:   loss: tensor(0.0106, grad_fn=<NllLossBackward>)\n",
      "COMET INFO: Uploading stats to Comet before program termination (may take several seconds)\n"
     ]
    }
   ],
   "source": [
    "# Run this cell once your model has completed training to signal the end of the experiment\n",
    "experiment.end()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}