Reproduce the unused-parameter error: compare a model that registers extra, never-used submodules against an identical model without them, checking parameters and gradients after each training step.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

torch.manual_seed(2809)

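# Compare every entry in modelA's state_dict against the same key in modelB.
# modelB may hold extra (unused) parameters; those keys are simply never visited.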
def check_params(modelA, modelB):
    for key in modelA.state_dict():
        is_equal = (modelA.state_dict()[key] == modelB.state_dict()[key]).all()
        print('Checking {}, is equal = {}'.format(key, is_equal))
        if not is_equal:
            print('ERROR!')

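# Compare gradients parameter-by-parameter, again iterating only over modelA.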
def check_grads(modelA, modelB):
    for name, paramA in modelA.named_parameters():
        # Resolve a dotted name such as 'conv_list_1.0.weight' to the matching
        # parameter on modelB (works for any nesting depth, incl. ModuleList).
        paramB = modelB
        for attr in name.split('.'):
            paramB = getattr(paramB, attr)
        is_equal = (paramA.grad == paramB.grad).all()
        print('Gradient for {} is equal = {}'.format(name, is_equal))
        if not is_equal:
            print('ERROR!')

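# Copy the parameters both models have in common from modelB into modelA,
# dropping modelB's extra entries so load_state_dict does not complain.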
def copy_params(modelA, modelB):
    modelA_dict = modelA.state_dict()
    modelB_dict = modelB.state_dict()
    equal_dict = {k: v for k, v in modelB_dict.items() if k in modelA_dict}
    modelA.load_state_dict(equal_dict)

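# Baseline model: every registered module participates in forward.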
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv_list_1 = nn.ModuleList([])
        self.conv_list_2 = nn.ModuleList([])
        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)
        self.fc = nn.Linear(12 * 6 * 6, 2)
        self.conv_list_1.extend(self.make_layers_1())
        self.conv_list_2.extend(self.make_layers_2())

    def make_layers_1(self):
        conv1 = nn.Conv2d(3, 6, 3, 1, 1)
        return [conv1]

    def make_layers_2(self):
        conv2 = nn.Conv2d(6, 12, 3, 1, 1)
        return [con2] if False else [conv2]

    def forward(self, x):
        x = F.relu(self.conv_list_1[0](x))
        x = self.pool1(x)
        x = F.relu(self.conv_list_2[0](x))
        x = self.pool2(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

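# Same forward pass, but make_layers_* additionally register conv layers
# (conv_unused1 / conv_unused2) that forward never touches.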
class MyModelUnused(nn.Module):
    def __init__(self):
        super(MyModelUnused, self).__init__()
        self.conv_list_1 = nn.ModuleList([])
        self.conv_list_2 = nn.ModuleList([])
        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)
        self.fc = nn.Linear(12 * 6 * 6, 2)
        self.conv_list_1.extend(self.make_layers_1())
        self.conv_list_2.extend(self.make_layers_2())

    def make_layers_1(self):
        conv1 = nn.Conv2d(3, 6, 3, 1, 1)
        conv_unused1 = nn.Conv2d(12, 24, 3, 1, 1)
        return [conv1, conv_unused1]

    def make_layers_2(self):
        conv2 = nn.Conv2d(6, 12, 3, 1, 1)
        conv_unused2 = nn.Conv2d(24, 12, 3, 1, 1)
        return [conv2, conv_unused2]

    def forward(self, x):
        x = F.relu(self.conv_list_1[0](x))
        x = self.pool1(x)
        x = F.relu(self.conv_list_2[0](x))
        x = self.pool2(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

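# Copy test: load modelB's shared parameters into modelA and verify they match.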
print('Copy test')
modelA = MyModel()
modelB = MyModelUnused()
copy_params(modelA, modelB)
# Check weights for equality
check_params(modelA, modelB)

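# Training test: re-seed before constructing each model so both start from the
# same RNG state. Note that the unused convs in MyModelUnused also draw random
# numbers during init, so layers created after them can end up different.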
print('Training test')
x = torch.randn(10, 3, 24, 24)
target = torch.empty(10, dtype=torch.long).random_(2)
criterion = nn.CrossEntropyLoss()
torch.manual_seed(2809)
modelA = MyModel()
torch.manual_seed(2809)
modelB = MyModelUnused()
print(modelA.state_dict().keys())
print(modelB.state_dict().keys())
# Check weights for equality
check_params(modelA, modelB)
optimizerA = optim.SGD(modelA.parameters(), lr=1e-3)
optimizerB = optim.SGD(modelB.parameters(), lr=1e-3)

for epoch in range(10):
    print('Checking epoch {}'.format(epoch))
    optimizerA.zero_grad()
    optimizerB.zero_grad()
    check_params(modelA, modelB)
    outputA = modelA(x)
    outputB = modelB(x)
    print('Outputs are equal = {}'.format((outputA == outputB).all()))
    lossA = criterion(outputA, target)
    lossB = criterion(outputB, target)
    print('Losses are equal = {}'.format((lossA == lossB).all()))
    lossA.backward()
    lossB.backward()
    check_grads(modelA, modelB)
    optimizerA.step()
    optimizerB.step()
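
# ---------------------------------------------------------------------------
# Note: this single-process script only compares the numerics side by side.
# The "unused param" error itself comes from DistributedDataParallel (DDP),
# which expects every registered parameter to receive a gradient in backward.
# A minimal sketch, assuming a single-node 'gloo' process group with the env
# vars set by a launcher such as torch.distributed.launch (kept commented out
# because it needs that distributed setup to run):
#
#   import torch.distributed as dist
#   dist.init_process_group(backend='gloo', init_method='env://')
#   ddp_model = nn.parallel.DistributedDataParallel(MyModelUnused())
#   loss = criterion(ddp_model(x), target)
#   loss.backward()  # conv_unused1/conv_unused2 never receive grads -> error
#
# Newer PyTorch versions accept find_unused_parameters=True to skip the check.
# ---------------------------------------------------------------------------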