MCPcopy
hub / github.com/appvision-ai/fast-bert / step

Method step

fast_bert/optimization.py:237–302  ·  view source on GitHub ↗

Performs a single optimization step. Argument: `closure` (callable, optional) — a closure that re-evaluates the model and returns the loss.

(self, closure=None)

Source from the content-addressed store, hash-verified

235 return lr
236
237 def step(self, closure=None):
238 """Performs a single optimization step.
239
240 Arguments:
241 closure (callable, optional): A closure that reevaluates the model
242 and returns the loss.
243 """
244 loss = None
245 if closure is not None:
246 loss = closure()
247
248 for group in self.param_groups:
249 for p in group['params']:
250 if p.grad is None:
251 continue
252 grad = p.grad.data
253 if grad.is_sparse:
254 raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
255
256 state = self.state[p]
257
258 # State initialization
259 if len(state) == 0:
260 state['step'] = 0
261 # Exponential moving average of gradient values
262 state['next_m'] = torch.zeros_like(p.data)
263 # Exponential moving average of squared gradient values
264 state['next_v'] = torch.zeros_like(p.data)
265
266 next_m, next_v = state['next_m'], state['next_v']
267 beta1, beta2 = group['b1'], group['b2']
268
269 # Add grad clipping
270 if group['max_grad_norm'] > 0:
271 clip_grad_norm_(p, group['max_grad_norm'])
272
273 # Decay the first and second moment running average coefficient
274 # In-place operations to update the averages at the same time
275 next_m.mul_(beta1).add_(1 - beta1, grad)
276 next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
277 update = next_m / (next_v.sqrt() + group['e'])
278
279 # Just adding the square of the weights to the loss function is *not*
280 # the correct way of using L2 regularization/weight decay with Adam,
281 # since that will interact with the m and v parameters in strange ways.
282 #
283 # Instead we want to decay the weights in a manner that doesn't interact
284 # with the m/v parameters. This is equivalent to adding the square
285 # of the weights to the loss with plain (non-momentum) SGD.
286 if group['weight_decay'] > 0.0:
287 update += group['weight_decay'] * p.data
288
289 lr_scheduled = group['lr']
290 lr_scheduled *= group['schedule'].get_lr(state['step'])
291
292 update_with_lr = lr_scheduled * update
293 p.data.add_(-update_with_lr)
294

Callers 9

fit · Method · 0.45
fit · Method · 0.45
lr_find · Method · 0.45
_train_batch · Method · 0.45
fit · Method · 0.45
fit · Method · 0.45
lr_find · Method · 0.45
_train_batch · Method · 0.45
plot_pr_curve · Function · 0.45

Calls 1

get_lr · Method · 0.45

Tested by

no test coverage detected