| 82 | |
| 83 | |
def sgd_optimizer(model, lr, momentum, weight_decay):
    """Build a ``torch.optim.SGD`` with one parameter group per trainable tensor.

    Frozen parameters (``requires_grad`` is False) are left out entirely.
    Every tensor with fewer than two dimensions is exempt from weight decay,
    and a message naming it is printed. Every parameter whose name contains
    the substring ``'bias'`` is trained at twice the base learning rate.

    Args:
        model: object exposing ``named_parameters()`` (e.g. an ``nn.Module``).
        lr: base learning rate; also passed as the optimizer default.
        momentum: SGD momentum factor.
        weight_decay: decay applied to parameters with two or more dimensions.

    Returns:
        The configured ``torch.optim.SGD`` optimizer.
    """
    groups = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            # Tensors with ndim < 2 get no weight decay.
            no_decay = param.ndimension() < 2
            if no_decay:
                print('set weight decay=0 for {}'.format(name))
            groups.append({
                'params': [param],
                # Caffe-style convention: biases get 2x the base lr
                # (the original author noted it made no difference).
                'lr': 2 * lr if 'bias' in name else lr,
                'weight_decay': 0 if no_decay else weight_decay,
            })
    return torch.optim.SGD(groups, lr, momentum=momentum)
| 99 | |
| 100 | def main(): |
| 101 | args = parser.parse_args() |