Last active
April 8, 2020 23:07
-
-
Save zhezh/ccc7e7b70338c6b882e08113d7706530 to your computer and use it in GitHub Desktop.
[pytorch 分层设置学习率] #pytorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# PyTorch 分层设置学习率\n", | |
| "\n", | |
| "通过 optimizer 的参数组(param groups)为网络的不同层设置不同的学习率。" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "import torch.nn.functional as F\n", | |
| "import torch.optim as optim" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "0.4.0\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Record the PyTorch version this notebook was executed with.\n", | |
| "print(torch.__version__)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Minimal two-layer fully connected network used to demonstrate per-layer learning rates.\n", | |
| "class TwoLayerNet(nn.Module):\n", | |
| "    \"\"\"Fully connected network computing linear2(relu(linear1(x))).\"\"\"\n", | |
| "\n", | |
| "    def __init__(self, D_in, H, D_out):\n", | |
| "        \"\"\"\n", | |
| "        Create the two affine layers and register them as submodules.\n", | |
| "\n", | |
| "        linear1 maps D_in features to H features; linear2 maps H features to D_out.\n", | |
| "        \"\"\"\n", | |
| "        super().__init__()\n", | |
| "        self.linear1 = nn.Linear(D_in, H)\n", | |
| "        self.linear2 = nn.Linear(H, D_out)\n", | |
| "\n", | |
| "    def forward(self, x):\n", | |
| "        \"\"\"\n", | |
| "        Run x through linear1, a ReLU non-linearity and linear2, and return\n", | |
| "        the resulting Tensor.\n", | |
| "        \"\"\"\n", | |
| "        return self.linear2(F.relu(self.linear1(x)))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Problem sizes used by the cells below.\n", | |
| "N = 64        # batch size\n", | |
| "D_in = 1000   # input dimension\n", | |
| "H = 100       # hidden dimension\n", | |
| "D_out = 10    # output dimension" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Fix the RNG seed so the random data and the initial weights are the same\n", | |
| "# after every kernel restart.\n", | |
| "torch.manual_seed(0)\n", | |
| "\n", | |
| "# Random inputs of shape (N, D_in) and regression targets of shape (N, D_out).\n", | |
| "x = torch.randn(N, D_in)\n", | |
| "y = torch.randn(N, D_out)\n", | |
| "\n", | |
| "# Construct our model by instantiating the class defined above.\n", | |
| "model = TwoLayerNet(D_in, H, D_out)\n", | |
| "\n", | |
| "# Squared-error loss summed (not averaged) over all elements.\n", | |
| "# `reduction='sum'` is the current spelling; old releases such as 0.4.0 (the\n", | |
| "# version this notebook was recorded with) only know the since-deprecated\n", | |
| "# `size_average=False` and reject the new keyword with a TypeError.\n", | |
| "try:\n", | |
| "    criterion = nn.MSELoss(reduction='sum')\n", | |
| "except TypeError:\n", | |
| "    criterion = nn.MSELoss(size_average=False)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "查看模型的参数名称" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "参数名: linear1.weight , id: 140705409071936\n", | |
| "参数名: linear1.bias , id: 140705409072008\n", | |
| "参数名: linear2.weight , id: 140705409072296\n", | |
| "参数名: linear2.bias , id: 140705409072656\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# One row per learnable parameter: its qualified name (left-aligned, padded\n", | |
| "# to 18 characters) and the id() of the parameter object.\n", | |
| "row_template = '参数名: {: <18}, id: {}'\n", | |
| "for param_name, param in model.named_parameters():\n", | |
| "    print(row_template.format(param_name, id(param)))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Materialise the parameters as a list: model.parameters() returns a one-shot\n", | |
| "# generator, which would be silently empty if iterated a second time.\n", | |
| "all_parameters = list(model.parameters())\n", | |
| "\n", | |
| "# Group 1: every parameter whose qualified name contains 'linear1'.\n", | |
| "lin1_parameters = [p for pname, p in model.named_parameters() if 'linear1' in pname]\n", | |
| "\n", | |
| "# Group 2: everything else. Membership is decided by object identity (id),\n", | |
| "# with the ids kept in a set for constant-time lookups.\n", | |
| "lin1_parameters_id = set(map(id, lin1_parameters))\n", | |
| "other_parameters = [p for p in all_parameters if id(p) not in lin1_parameters_id]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "现在获得了两组参数,一组是linear1,另一组是其他的(本程序中即linear2)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "linear1组参数id: \n", | |
| "140705409071936\n", | |
| "140705409072008\n", | |
| "\n", | |
| "\n", | |
| "other组参数id: \n", | |
| "140705409072296\n", | |
| "140705409072656\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Show which parameter objects ended up in each group, identified by id().\n", | |
| "print('linear1组参数id: ')\n", | |
| "for param_id in map(id, lin1_parameters):\n", | |
| "    print(param_id)\n", | |
| "\n", | |
| "print('\\n')\n", | |
| "print('other组参数id: ')\n", | |
| "for param_id in map(id, other_parameters):\n", | |
| "    print(param_id)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "构造 optimizer(`optim.SGD`),并为两组参数分别指定学习率" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Two parameter groups: the linear1 group sets no 'lr' of its own and therefore\n", | |
| "# uses the optimizer-wide default (1e-4); the remaining parameters override it\n", | |
| "# with 1e-3.\n", | |
| "param_groups = [\n", | |
| "    {'params': lin1_parameters},\n", | |
| "    {'params': other_parameters, 'lr': 1e-3},\n", | |
| "]\n", | |
| "optimizer = optim.SGD(param_groups, lr=1e-4)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "训练网络" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 25, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "0 719.4832153320312\n", | |
| "1 618.053466796875\n", | |
| "2 556.1675415039062\n", | |
| "3 500.43212890625\n", | |
| "4 447.5682678222656\n", | |
| "5 396.4813537597656\n", | |
| "6 347.07867431640625\n", | |
| "7 300.09326171875\n", | |
| "8 256.21868896484375\n", | |
| "9 216.0878448486328\n", | |
| "10 180.35525512695312\n", | |
| "11 149.25791931152344\n", | |
| "12 122.67479705810547\n", | |
| "13 100.36481475830078\n", | |
| "14 81.89598083496094\n", | |
| "15 66.74868774414062\n", | |
| "16 54.42697525024414\n", | |
| "17 44.45567321777344\n", | |
| "18 36.39970397949219\n", | |
| "19 29.910730361938477\n", | |
| "20 24.662458419799805\n", | |
| "21 20.40622901916504\n", | |
| "22 16.947093963623047\n", | |
| "23 14.121529579162598\n", | |
| "24 11.805665969848633\n", | |
| "25 9.908512115478516\n", | |
| "26 8.346722602844238\n", | |
| "27 7.056419849395752\n", | |
| "28 5.984086036682129\n", | |
| "29 5.090843677520752\n", | |
| "30 4.344158172607422\n", | |
| "31 3.7168867588043213\n", | |
| "32 3.1883127689361572\n", | |
| "33 2.7413690090179443\n", | |
| "34 2.3627569675445557\n", | |
| "35 2.0411109924316406\n", | |
| "36 1.7670531272888184\n", | |
| "37 1.5327770709991455\n", | |
| "38 1.3317700624465942\n", | |
| "39 1.1588612794876099\n", | |
| "40 1.0100197792053223\n", | |
| "41 0.8820691108703613\n", | |
| "42 0.7714465856552124\n", | |
| "43 0.6757064461708069\n", | |
| "44 0.592779815196991\n", | |
| "45 0.5206921696662903\n", | |
| "46 0.45796385407447815\n", | |
| "47 0.403344064950943\n", | |
| "48 0.3557352125644684\n", | |
| "49 0.31413477659225464\n", | |
| "50 0.27769696712493896\n", | |
| "51 0.24573519825935364\n", | |
| "52 0.2176705002784729\n", | |
| "53 0.19302386045455933\n", | |
| "54 0.17133362591266632\n", | |
| "55 0.15221278369426727\n", | |
| "56 0.135351300239563\n", | |
| "57 0.12046048790216446\n", | |
| "58 0.1073007881641388\n", | |
| "59 0.09565050154924393\n", | |
| "60 0.08532743155956268\n", | |
| "61 0.07617135345935822\n", | |
| "62 0.06804817914962769\n", | |
| "63 0.06084805354475975\n", | |
| "64 0.05446131154894829\n", | |
| "65 0.04884392023086548\n", | |
| "66 0.043833404779434204\n", | |
| "67 0.039361849427223206\n", | |
| "68 0.03537042811512947\n", | |
| "69 0.03180677816271782\n", | |
| "70 0.028617050498723984\n", | |
| "71 0.025749675929546356\n", | |
| "72 0.023186029866337776\n", | |
| "73 0.02088870480656624\n", | |
| "74 0.018828196451067924\n", | |
| "75 0.016980471089482307\n", | |
| "76 0.015320717357099056\n", | |
| "77 0.013830263167619705\n", | |
| "78 0.012490017339587212\n", | |
| "79 0.011284894309937954\n", | |
| "80 0.01020009908825159\n", | |
| "81 0.00922376848757267\n", | |
| "82 0.008343766443431377\n", | |
| "83 0.007550488226115704\n", | |
| "84 0.0068353088572621346\n", | |
| "85 0.006190172396600246\n", | |
| "86 0.005607489496469498\n", | |
| "87 0.005081566050648689\n", | |
| "88 0.00460641598328948\n", | |
| "89 0.004176917020231485\n", | |
| "90 0.0037886006757616997\n", | |
| "91 0.0034373654052615166\n", | |
| "92 0.0031197601929306984\n", | |
| "93 0.0028322283178567886\n", | |
| "94 0.0025717862881720066\n", | |
| "95 0.002335888333618641\n", | |
| "96 0.0021221640054136515\n", | |
| "97 0.0019284778973087668\n", | |
| "98 0.0017529240576550364\n", | |
| "99 0.0015937236603349447\n", | |
| "100 0.0014493277994915843\n", | |
| "101 0.0013182209804654121\n", | |
| "102 0.0011992763029411435\n", | |
| "103 0.0010912807192653418\n", | |
| "104 0.0009932058164849877\n", | |
| "105 0.0009040983277373016\n", | |
| "106 0.0008233404951170087\n", | |
| "107 0.0007499091443605721\n", | |
| "108 0.0006831525824964046\n", | |
| "109 0.0006224379176273942\n", | |
| "110 0.0005672484403476119\n", | |
| "111 0.0005170325748622417\n", | |
| "112 0.0004713317903224379\n", | |
| "113 0.0004297299892641604\n", | |
| "114 0.0003918729198630899\n", | |
| "115 0.0003573991998564452\n", | |
| "116 0.0003260155499447137\n", | |
| "117 0.0002974196686409414\n", | |
| "118 0.0002713669091463089\n", | |
| "119 0.0002476317167747766\n", | |
| "120 0.00022600177908316255\n", | |
| "121 0.0002062939602183178\n", | |
| "122 0.00018832141358871013\n", | |
| "123 0.0001719275169307366\n", | |
| "124 0.0001569933956488967\n", | |
| "125 0.00014336439198814332\n", | |
| "126 0.00013093784218654037\n", | |
| "127 0.0001196042139781639\n", | |
| "128 0.0001092585880542174\n", | |
| "129 9.982137999031693e-05\n", | |
| "130 9.120586764765903e-05\n", | |
| "131 8.33444792078808e-05\n", | |
| "132 7.616882066940889e-05\n", | |
| "133 6.961503822822124e-05\n", | |
| "134 6.363449210766703e-05\n", | |
| "135 5.817100827698596e-05\n", | |
| "136 5.318074545357376e-05\n", | |
| "137 4.8620247980579734e-05\n", | |
| "138 4.4461063225753605e-05\n", | |
| "139 4.0658000216353685e-05\n", | |
| "140 3.7181245716055855e-05\n", | |
| "141 3.400376590434462e-05\n", | |
| "142 3.110080797341652e-05\n", | |
| "143 2.8448537705116905e-05\n", | |
| "144 2.6025612896773964e-05\n", | |
| "145 2.3810694983694702e-05\n", | |
| "146 2.1784513592137955e-05\n", | |
| "147 1.9931732822442427e-05\n", | |
| "148 1.8238761185784824e-05\n", | |
| "149 1.669217635935638e-05\n", | |
| "150 1.5273904864443466e-05\n", | |
| "151 1.3978798961034045e-05\n", | |
| "152 1.279618481930811e-05\n", | |
| "153 1.1711815204762388e-05\n", | |
| "154 1.071973474608967e-05\n", | |
| "155 9.814746590564027e-06\n", | |
| "156 8.985691238194704e-06\n", | |
| "157 8.226681529777125e-06\n", | |
| "158 7.533013558713719e-06\n", | |
| "159 6.897260846017161e-06\n", | |
| "160 6.3151273934636265e-06\n", | |
| "161 5.784675977338338e-06\n", | |
| "162 5.296439212543191e-06\n", | |
| "163 4.851282938034274e-06\n", | |
| "164 4.442661975190276e-06\n", | |
| "165 4.068862381245708e-06\n", | |
| "166 3.7270233406161424e-06\n", | |
| "167 3.413488684600452e-06\n", | |
| "168 3.1275621950044297e-06\n", | |
| "169 2.8650440526689636e-06\n", | |
| "170 2.624645276227966e-06\n", | |
| "171 2.405057784926612e-06\n", | |
| "172 2.2035642359696794e-06\n", | |
| "173 2.019183966694982e-06\n", | |
| "174 1.850062517405604e-06\n", | |
| "175 1.6953827071120031e-06\n", | |
| "176 1.5531543340330245e-06\n", | |
| "177 1.4232090279620024e-06\n", | |
| "178 1.3044416391494451e-06\n", | |
| "179 1.1956101388932439e-06\n", | |
| "180 1.0957942322420422e-06\n", | |
| "181 1.0046904890259611e-06\n", | |
| "182 9.207004154632159e-07\n", | |
| "183 8.440935062026256e-07\n", | |
| "184 7.73405361087498e-07\n", | |
| "185 7.091608722475939e-07\n", | |
| "186 6.500075642179581e-07\n", | |
| "187 5.959356030871277e-07\n", | |
| "188 5.460437932924833e-07\n", | |
| "189 5.007885874874773e-07\n", | |
| "190 4.5921461833131616e-07\n", | |
| "191 4.2098395169887226e-07\n", | |
| "192 3.8619594988631434e-07\n", | |
| "193 3.539971658028662e-07\n", | |
| "194 3.244428512516606e-07\n", | |
| "195 2.975311303998751e-07\n", | |
| "196 2.726736170188815e-07\n", | |
| "197 2.5038809781108284e-07\n", | |
| "198 2.2951981293317658e-07\n", | |
| "199 2.1044718323537381e-07\n", | |
| "200 1.931230571017295e-07\n", | |
| "201 1.7704486765524052e-07\n", | |
| "202 1.625433725394032e-07\n", | |
| "203 1.4899816846991598e-07\n", | |
| "204 1.3666438292148086e-07\n", | |
| "205 1.2532166238088394e-07\n", | |
| "206 1.1494875451489861e-07\n", | |
| "207 1.054767295727288e-07\n", | |
| "208 9.676008971837291e-08\n", | |
| "209 8.872893886291422e-08\n", | |
| "210 8.135914697504631e-08\n", | |
| "211 7.470005414234038e-08\n", | |
| "212 6.852687306491134e-08\n", | |
| "213 6.293034005011577e-08\n", | |
| "214 5.7790256136058815e-08\n", | |
| "215 5.302708672161316e-08\n", | |
| "216 4.8682117892440147e-08\n", | |
| "217 4.453647051150256e-08\n", | |
| "218 4.099248585021087e-08\n", | |
| "219 3.761395817036828e-08\n", | |
| "220 3.453242669593237e-08\n", | |
| "221 3.170368501059784e-08\n", | |
| "222 2.915122720992258e-08\n", | |
| "223 2.6717120960029206e-08\n", | |
| "224 2.451883673870725e-08\n", | |
| "225 2.2580659120308155e-08\n", | |
| "226 2.0710798409595554e-08\n", | |
| "227 1.9040477639009623e-08\n", | |
| "228 1.7526446072224644e-08\n", | |
| "229 1.6122124080197864e-08\n", | |
| "230 1.4806063042271944e-08\n", | |
| "231 1.3633435713700237e-08\n", | |
| "232 1.2587334730085331e-08\n", | |
| "233 1.1583493275679757e-08\n", | |
| "234 1.0632591695980409e-08\n", | |
| "235 9.816670143436568e-09\n", | |
| "236 9.059893280038978e-09\n", | |
| "237 8.38562641547469e-09\n", | |
| "238 7.752122499482539e-09\n", | |
| "239 7.144894009769587e-09\n", | |
| "240 6.591914125664289e-09\n", | |
| "241 6.152251152968802e-09\n", | |
| "242 5.6817590632363135e-09\n", | |
| "243 5.275397008119853e-09\n", | |
| "244 4.890560845183245e-09\n", | |
| "245 4.5295536210687715e-09\n", | |
| "246 4.211996085246028e-09\n", | |
| "247 3.908982026956664e-09\n", | |
| "248 3.6504110845214655e-09\n", | |
| "249 3.4168203821849374e-09\n", | |
| "250 3.187831110196271e-09\n", | |
| "251 2.98542235377397e-09\n", | |
| "252 2.8059994328089033e-09\n", | |
| "253 2.626949102690901e-09\n", | |
| "254 2.4697992540012592e-09\n", | |
| "255 2.3114088421039014e-09\n", | |
| "256 2.1845512065965522e-09\n", | |
| "257 2.0542922918309614e-09\n", | |
| "258 1.9309427390368228e-09\n", | |
| "259 1.8381356436947272e-09\n", | |
| "260 1.7248911188261218e-09\n", | |
| "261 1.6398804536521538e-09\n", | |
| "262 1.5482324311477669e-09\n", | |
| "263 1.4635379574912122e-09\n", | |
| "264 1.3910612661760524e-09\n", | |
| "265 1.317505660125562e-09\n", | |
| "266 1.2539014271339965e-09\n", | |
| "267 1.2011907024600532e-09\n", | |
| "268 1.1445275838184443e-09\n", | |
| "269 1.0953595808160799e-09\n", | |
| "270 1.0399190397691882e-09\n", | |
| "271 1.0023235574863065e-09\n", | |
| "272 9.588654314995892e-10\n", | |
| "273 9.140234680238279e-10\n", | |
| "274 8.714531318787522e-10\n", | |
| "275 8.341509150078252e-10\n", | |
| "276 8.028351317079796e-10\n", | |
| "277 7.760427300773642e-10\n", | |
| "278 7.34666716351029e-10\n", | |
| "279 7.090298348444435e-10\n", | |
| "280 6.824536491478739e-10\n", | |
| "281 6.50852871597607e-10\n", | |
| "282 6.332582236368012e-10\n", | |
| "283 6.151034126489208e-10\n", | |
| "284 5.901392152729557e-10\n", | |
| "285 5.691445092992353e-10\n", | |
| "286 5.507771461132904e-10\n", | |
| "287 5.32691946109054e-10\n", | |
| "288 5.124157764768711e-10\n", | |
| "289 4.940373665718312e-10\n", | |
| "290 4.766944616818591e-10\n", | |
| "291 4.6186365842970645e-10\n", | |
| "292 4.470791514776806e-10\n", | |
| "293 4.3581016573313036e-10\n", | |
| "294 4.2106490516502504e-10\n", | |
| "295 4.066213477038616e-10\n", | |
| "296 3.981817098264173e-10\n", | |
| "297 3.8970893179168797e-10\n", | |
| "298 3.72095521061766e-10\n", | |
| "299 3.6663283520255163e-10\n", | |
| "300 3.573504825382656e-10\n", | |
| "301 3.487677924240984e-10\n", | |
| "302 3.3633168472491093e-10\n", | |
| "303 3.290345773621084e-10\n", | |
| "304 3.224511213595349e-10\n", | |
| "305 3.122619940398863e-10\n", | |
| "306 3.0483571222816863e-10\n", | |
| "307 2.9414332081145744e-10\n", | |
| "308 2.8667790363812173e-10\n", | |
| "309 2.8062335788447967e-10\n", | |
| "310 2.762368667141857e-10\n", | |
| "311 2.6522151141961103e-10\n", | |
| "312 2.57844356976733e-10\n", | |
| "313 2.5349577992273e-10\n", | |
| "314 2.478507399317209e-10\n", | |
| "315 2.41882402995941e-10\n", | |
| "316 2.3629426193494396e-10\n", | |
| "317 2.3070964583205011e-10\n", | |
| "318 2.264485127190241e-10\n", | |
| "319 2.2164882429454025e-10\n", | |
| "320 2.1533744232193897e-10\n", | |
| "321 2.097950702051321e-10\n", | |
| "322 2.064158427517171e-10\n", | |
| "323 1.990261011552974e-10\n", | |
| "324 1.9497652103961371e-10\n", | |
| "325 1.9140954099494678e-10\n", | |
| "326 1.8693983860895713e-10\n", | |
| "327 1.8207133023473432e-10\n", | |
| "328 1.783010128431073e-10\n", | |
| "329 1.763045681668629e-10\n", | |
| "330 1.7145321273837055e-10\n", | |
| "331 1.6931771262829187e-10\n", | |
| "332 1.6673835923075586e-10\n", | |
| "333 1.6416644432748484e-10\n", | |
| "334 1.6168200112076647e-10\n", | |
| "335 1.5868138747432425e-10\n", | |
| "336 1.5622357574240908e-10\n", | |
| "337 1.5223147742382537e-10\n", | |
| "338 1.4958895233618819e-10\n", | |
| "339 1.4682752236261365e-10\n", | |
| "340 1.436487179207191e-10\n", | |
| "341 1.4157740257925155e-10\n", | |
| "342 1.410829092440835e-10\n", | |
| "343 1.380528469319131e-10\n", | |
| "344 1.3580145341585137e-10\n", | |
| "345 1.341790289988154e-10\n", | |
| "346 1.3181483682345174e-10\n", | |
| "347 1.3028327028319353e-10\n", | |
| "348 1.2861381404327688e-10\n", | |
| "349 1.2620478273550617e-10\n", | |
| "350 1.2542905603041277e-10\n", | |
| "351 1.238803642999997e-10\n", | |
| "352 1.218108947043106e-10\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "353 1.2094443502252972e-10\n", | |
| "354 1.1777642749954964e-10\n", | |
| "355 1.1760770135538223e-10\n", | |
| "356 1.1544241951266798e-10\n", | |
| "357 1.1402352673162142e-10\n", | |
| "358 1.1242246023002167e-10\n", | |
| "359 1.1167516911214648e-10\n", | |
| "360 1.0990484911044263e-10\n", | |
| "361 1.0966394459188678e-10\n", | |
| "362 1.0963921437401325e-10\n", | |
| "363 1.0603967703914918e-10\n", | |
| "364 1.03834024711702e-10\n", | |
| "365 1.039718450224214e-10\n", | |
| "366 1.0166703590108739e-10\n", | |
| "367 1.0114989401621699e-10\n", | |
| "368 9.834176528666916e-11\n", | |
| "369 9.60351104195567e-11\n", | |
| "370 9.521219923591673e-11\n", | |
| "371 9.493408836824813e-11\n", | |
| "372 9.308730175572322e-11\n", | |
| "373 9.198990180703248e-11\n", | |
| "374 9.021075553228286e-11\n", | |
| "375 8.961704989207675e-11\n", | |
| "376 8.807096718577156e-11\n", | |
| "377 8.61079402225684e-11\n", | |
| "378 8.594966405262028e-11\n", | |
| "379 8.452996635988086e-11\n", | |
| "380 8.392590788997012e-11\n", | |
| "381 8.171358034658738e-11\n", | |
| "382 8.029973908030286e-11\n", | |
| "383 7.93237697749305e-11\n", | |
| "384 7.853767636234465e-11\n", | |
| "385 7.892052289459883e-11\n", | |
| "386 7.902740961629462e-11\n", | |
| "387 7.743728575038133e-11\n", | |
| "388 7.645977601056231e-11\n", | |
| "389 7.571010485207808e-11\n", | |
| "390 7.520549460959813e-11\n", | |
| "391 7.54181994633285e-11\n", | |
| "392 7.461733314562125e-11\n", | |
| "393 7.3604816686057e-11\n", | |
| "394 7.3249614707116e-11\n", | |
| "395 7.211352348601707e-11\n", | |
| "396 7.178711791677728e-11\n", | |
| "397 7.167869769952873e-11\n", | |
| "398 7.025070108968023e-11\n", | |
| "399 6.825126575016327e-11\n", | |
| "400 6.793333950927405e-11\n", | |
| "401 6.772558902579107e-11\n", | |
| "402 6.730412061006774e-11\n", | |
| "403 6.625886644906487e-11\n", | |
| "404 6.461721435702117e-11\n", | |
| "405 6.47099734907286e-11\n", | |
| "406 6.391782242376465e-11\n", | |
| "407 6.264266882993752e-11\n", | |
| "408 6.173965505507084e-11\n", | |
| "409 6.116362971653189e-11\n", | |
| "410 6.077430919626536e-11\n", | |
| "411 6.030231869402769e-11\n", | |
| "412 5.947273923334606e-11\n", | |
| "413 5.867167862660949e-11\n", | |
| "414 5.855109452834739e-11\n", | |
| "415 5.7634869099487673e-11\n", | |
| "416 5.798097418852066e-11\n", | |
| "417 5.714314438298729e-11\n", | |
| "418 5.7525491314880384e-11\n", | |
| "419 5.753773499317383e-11\n", | |
| "420 5.72689846933816e-11\n", | |
| "421 5.631775948367057e-11\n", | |
| "422 5.598362051717487e-11\n", | |
| "423 5.5867376697049664e-11\n", | |
| "424 5.629419153052595e-11\n", | |
| "425 5.5520841396594633e-11\n", | |
| "426 5.46322084793438e-11\n", | |
| "427 5.4010414196614676e-11\n", | |
| "428 5.341607364761636e-11\n", | |
| "429 5.3519549902958374e-11\n", | |
| "430 5.247411186126705e-11\n", | |
| "431 5.307565845158457e-11\n", | |
| "432 5.2371048470112314e-11\n", | |
| "433 5.2449510012930745e-11\n", | |
| "434 5.1879195384074706e-11\n", | |
| "435 5.1574945703070085e-11\n", | |
| "436 5.1512395043973314e-11\n", | |
| "437 5.08322897663227e-11\n", | |
| "438 5.057060326052465e-11\n", | |
| "439 4.9787240302689995e-11\n", | |
| "440 4.823476340565236e-11\n", | |
| "441 4.8459992962879284e-11\n", | |
| "442 4.797246627719076e-11\n", | |
| "443 4.824013410953398e-11\n", | |
| "444 4.7559997606860804e-11\n", | |
| "445 4.690832444698145e-11\n", | |
| "446 4.619358368040949e-11\n", | |
| "447 4.583974866356755e-11\n", | |
| "448 4.605983650041168e-11\n", | |
| "449 4.5755635391664384e-11\n", | |
| "450 4.547212259509159e-11\n", | |
| "451 4.5017298916372184e-11\n", | |
| "452 4.5089352390670356e-11\n", | |
| "453 4.437321690642371e-11\n", | |
| "454 4.417576374149412e-11\n", | |
| "455 4.4186581477090314e-11\n", | |
| "456 4.324371069563959e-11\n", | |
| "457 4.303992579002269e-11\n", | |
| "458 4.266882333570088e-11\n", | |
| "459 4.2662408328286716e-11\n", | |
| "460 4.2061472360632735e-11\n", | |
| "461 4.153392560435343e-11\n", | |
| "462 4.1709347781138106e-11\n", | |
| "463 4.15860852698291e-11\n", | |
| "464 4.1531614952683427e-11\n", | |
| "465 4.106182407981329e-11\n", | |
| "466 4.0790346794716825e-11\n", | |
| "467 4.062897934753451e-11\n", | |
| "468 3.990976646384148e-11\n", | |
| "469 4.011613263799063e-11\n", | |
| "470 3.9666533946380866e-11\n", | |
| "471 3.945664281412853e-11\n", | |
| "472 3.835248785222234e-11\n", | |
| "473 3.8230120458226935e-11\n", | |
| "474 3.747497104300557e-11\n", | |
| "475 3.7175280215295814e-11\n", | |
| "476 3.713800100779707e-11\n", | |
| "477 3.680154792018442e-11\n", | |
| "478 3.743382687160235e-11\n", | |
| "479 3.680740781608627e-11\n", | |
| "480 3.6479482629081517e-11\n", | |
| "481 3.598047901287593e-11\n", | |
| "482 3.572271645158054e-11\n", | |
| "483 3.5526505348659754e-11\n", | |
| "484 3.541624979397362e-11\n", | |
| "485 3.4761606787503396e-11\n", | |
| "486 3.4769877949036854e-11\n", | |
| "487 3.460022199308632e-11\n", | |
| "488 3.453215144388899e-11\n", | |
| "489 3.4166957457726355e-11\n", | |
| "490 3.428968220475781e-11\n", | |
| "491 3.439980245101282e-11\n", | |
| "492 3.4364174700263206e-11\n", | |
| "493 3.4181008717881767e-11\n", | |
| "494 3.368537393466653e-11\n", | |
| "495 3.368653272994848e-11\n", | |
| "496 3.291827227469568e-11\n", | |
| "497 3.2805473615393765e-11\n", | |
| "498 3.3228246543171025e-11\n", | |
| "499 3.3127479925898484e-11\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "for step in range(500):\n", | |
| "    # Forward pass on the whole batch, then the loss against the targets.\n", | |
| "    loss = criterion(model(x), y)\n", | |
| "    print(step, loss.item())\n", | |
| "\n", | |
| "    # Clear the gradients left over from the previous step, backpropagate,\n", | |
| "    # and let the optimizer update every parameter group.\n", | |
| "    optimizer.zero_grad()\n", | |
| "    loss.backward()\n", | |
| "    optimizer.step()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.2" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment