Performs a single optimization step. Arguments: `closure` (callable, optional) is a closure that reevaluates the model and returns the loss.
step(self, closure=None)
| 235 | return lr |
| 236 | |
| 237 | def step(self, closure=None): |
| 238 | """Performs a single optimization step. |
| 239 | |
| 240 | Arguments: |
| 241 | closure (callable, optional): A closure that reevaluates the model |
| 242 | and returns the loss. |
| 243 | """ |
| 244 | loss = None |
| 245 | if closure is not None: |
| 246 | loss = closure() |
| 247 | |
| 248 | for group in self.param_groups: |
| 249 | for p in group['params']: |
| 250 | if p.grad is None: |
| 251 | continue |
| 252 | grad = p.grad.data |
| 253 | if grad.is_sparse: |
| 254 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') |
| 255 | |
| 256 | state = self.state[p] |
| 257 | |
| 258 | # State initialization |
| 259 | if len(state) == 0: |
| 260 | state['step'] = 0 |
| 261 | # Exponential moving average of gradient values |
| 262 | state['next_m'] = torch.zeros_like(p.data) |
| 263 | # Exponential moving average of squared gradient values |
| 264 | state['next_v'] = torch.zeros_like(p.data) |
| 265 | |
| 266 | next_m, next_v = state['next_m'], state['next_v'] |
| 267 | beta1, beta2 = group['b1'], group['b2'] |
| 268 | |
| 269 | # Add grad clipping |
| 270 | if group['max_grad_norm'] > 0: |
| 271 | clip_grad_norm_(p, group['max_grad_norm']) |
| 272 | |
| 273 | # Decay the first and second moment running average coefficient |
| 274 | # In-place operations to update the averages at the same time |
| 275 | next_m.mul_(beta1).add_(1 - beta1, grad) |
| 276 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) |
| 277 | update = next_m / (next_v.sqrt() + group['e']) |
| 278 | |
| 279 | # Just adding the square of the weights to the loss function is *not* |
| 280 | # the correct way of using L2 regularization/weight decay with Adam, |
| 281 | # since that will interact with the m and v parameters in strange ways. |
| 282 | # |
| 283 | # Instead we want to decay the weights in a manner that doesn't interact |
| 284 | # with the m/v parameters. This is equivalent to adding the square |
| 285 | # of the weights to the loss with plain (non-momentum) SGD. |
| 286 | if group['weight_decay'] > 0.0: |
| 287 | update += group['weight_decay'] * p.data |
| 288 | |
| 289 | lr_scheduled = group['lr'] |
| 290 | lr_scheduled *= group['schedule'].get_lr(state['step']) |
| 291 | |
| 292 | update_with_lr = lr_scheduled * update |
| 293 | p.data.add_(-update_with_lr) |
| 294 |
no test coverage detected