Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss.
(self, closure=None)
| 120 | return lr |
| 121 | |
| 122 | def step(self, closure=None): |
| 123 | """Performs a single optimization step. |
| 124 | Arguments: |
| 125 | closure (callable, optional): A closure that reevaluates the model |
| 126 | and returns the loss. |
| 127 | """ |
| 128 | loss = None |
| 129 | if closure is not None: |
| 130 | loss = closure() |
| 131 | |
| 132 | for group in self.param_groups: |
| 133 | for p in group['params']: |
| 134 | if p.grad is None: |
| 135 | continue |
| 136 | grad = p.grad.data |
| 137 | if grad.is_sparse: |
| 138 | raise RuntimeError( |
| 139 | 'Adam does not support sparse gradients, please consider SparseAdam instead') |
| 140 | state = self.state[p] |
| 141 | device = p.device |
| 142 | # State initialization |
| 143 | if len(state) == 0: |
| 144 | state['step'] = 0 |
| 145 | # Exponential moving average of gradient values |
| 146 | state['next_m'] = torch.zeros_like(p.data) |
| 147 | # Exponential moving average of squared gradient values |
| 148 | state['next_v'] = torch.zeros_like(p.data) |
| 149 | if 'is_embedding' in group and group['is_embedding']: |
| 150 | vocab_size = p.data.size(0) |
| 151 | state['b1_correction'] = torch.ones([vocab_size], |
| 152 | device=device) |
| 153 | state['b1_correction'][:] = group['b1'] |
| 154 | state['b2_correction'] = torch.ones([vocab_size], |
| 155 | device=device) |
| 156 | state['b2_correction'][:] = group['b2'] |
| 157 | state['ones'] = torch.ones([vocab_size], device=device) |
| 158 | state['zeros'] = torch.zeros([vocab_size], |
| 159 | device=device) |
| 160 | |
| 161 | state['b1'] = torch.ones([vocab_size], device=device) |
| 162 | state['b1'][:] = group['b1'] |
| 163 | state['b2'] = torch.ones([vocab_size], device=device) |
| 164 | state['b2'][:] = group['b2'] |
| 165 | |
| 166 | next_m, next_v = state['next_m'], state['next_v'] |
| 167 | beta1, beta2 = group['b1'], group['b2'] |
| 168 | |
| 169 | # Add grad clipping |
| 170 | if group['max_grad_norm'] > 0: |
| 171 | clip_grad_norm_(p, group['max_grad_norm']) |
| 172 | |
| 173 | # Decay the first and second moment running average coefficient |
| 174 | # In-place operations to update the averages at the same time |
| 175 | next_m.mul_(beta1).add_(1 - beta1, grad) |
| 176 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) |
| 177 | update = next_m / (next_v.sqrt() + group['e']) |
| 178 | |
| 179 | # Just adding the square of the weights to the loss function is *not* |