Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss.
(self, closure=None)
| 70 | |
@torch.no_grad()
def step(self, closure=None):
    """Performs a single optimization step.

    For every parameter ``p`` with a gradient ``g`` the update is::

        local_lr = eta * ||p|| / (||g|| + weight_decay * ||p||)
        d_p      = lr * local_lr * (g + weight_decay * p)

    optionally followed by (Nesterov) momentum applied to the already
    scaled ``d_p``, and finally ``p -= d_p``. Param groups with
    ``lars_exclude=True`` use ``local_lr = 1`` (plain SGD with weight
    decay and momentum). When either ``||p||`` or ``||g||`` is zero the
    trust ratio is undefined or degenerate, so ``local_lr = 1`` is used
    as well.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.

    Returns:
        The value returned by ``closure``, or ``None`` if no closure is given.
    """
    loss = None
    if closure is not None:
        # The step itself runs under no_grad; the closure needs autograd
        # re-enabled to recompute gradients.
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        eta = group['eta']
        nesterov = group['nesterov']
        lr = group['lr']
        lars_exclude = group.get('lars_exclude', False)

        for p in group['params']:
            if p.grad is None:
                continue

            d_p = p.grad

            if lars_exclude:
                local_lr = 1.
            else:
                weight_norm = torch.norm(p).item()
                grad_norm = torch.norm(d_p).item()
                # Layer-wise trust ratio. Fall back to 1 when either norm is
                # zero: a zero gradient (with weight_decay == 0) or a zero
                # weight with a zero gradient would otherwise divide by zero
                # (ZeroDivisionError on Python floats), and a zero-initialised
                # weight would get local_lr == 0 and never be updated.
                if weight_norm > 0 and grad_norm > 0:
                    local_lr = eta * weight_norm / \
                        (grad_norm + weight_decay * weight_norm)
                else:
                    local_lr = 1.

            actual_lr = local_lr * lr
            # Fold weight decay into the gradient, then scale the whole
            # update by the effective (global * layer-wise) learning rate.
            d_p = d_p.add(p, alpha=weight_decay).mul(actual_lr)
            if momentum != 0:
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    # First step: buffer starts as the scaled update itself.
                    buf = param_state['momentum_buffer'] = \
                        torch.clone(d_p).detach()
                else:
                    buf = param_state['momentum_buffer']
                    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                if nesterov:
                    d_p = d_p.add(buf, alpha=momentum)
                else:
                    d_p = buf
            p.add_(-d_p)

    return loss