Single training step.
train_step(data_iterator, model, optimizer, lr_scheduler, args, timers)
| 268 | |
| 269 | |
def train_step(data_iterator, model, optimizer, lr_scheduler,
               args, timers):
    """Single training step.

    Runs one timed forward pass, one timed backward pass (via
    ``backward_step``), and one timed ``optimizer.step()``. The learning-rate
    scheduler then advances, unless fp16 is enabled and the optimizer reports
    an overflow.

    Args:
        data_iterator: source of batches, passed straight to ``forward_step``.
        model: the model, passed to ``forward_step`` and ``backward_step``.
        optimizer: must provide ``step()``. When ``args.fp16`` is truthy it
            must also expose an ``overflow`` attribute.
        lr_scheduler: must provide ``step()``.
        args: run configuration. ``args.fp16`` is read here; the whole object
            is also forwarded to the helpers.
        timers: callable taking a timer name ('forward', 'backward',
            'optimizer') and returning an object with ``start()``/``stop()``.

    Returns:
        A 3-tuple ``(lm_loss_reduced, nsp_loss_reduced, skipped_iter)``. The
        two losses are whatever ``backward_step`` returns. ``skipped_iter`` is
        1 when the scheduler step was skipped and 0 otherwise.
    """
    # --- forward -----------------------------------------------------------
    timers('forward').start()
    raw_lm, raw_nsp = forward_step(data_iterator, model, args, timers)
    timers('forward').stop()

    # --- backward: gradients, cross-process reduction, clipping ------------
    timers('backward').start()
    reduced = backward_step(optimizer, model, raw_lm, raw_nsp, args)
    timers('backward').stop()
    lm_reduced, nsp_reduced = reduced

    # --- parameter update --------------------------------------------------
    timers('optimizer').start()
    optimizer.step()
    timers('optimizer').stop()

    # --- learning-rate schedule --------------------------------------------
    # optimizer.overflow is only consulted when fp16 is on (short-circuit).
    # NOTE(review): presumably an overflow means the optimizer skipped this
    # update, so the schedule is held back as well -- confirm in the fp16
    # optimizer wrapper.
    if args.fp16 and optimizer.overflow:
        skipped = 1
    else:
        lr_scheduler.step()
        skipped = 0

    return lm_reduced, nsp_reduced, skipped
| 299 | |
| 300 | |
| 301 | def train(model, optimizer, lr_scheduler, |
no test coverage detected