Train the model.
train(model, optimizer, lr_scheduler, train_data_iterator, val_data_iterator, timers, args)
| 389 | |
| 390 | |
| 391 | def train(model, optimizer, lr_scheduler, |
| 392 | train_data_iterator, val_data_iterator, timers, args): |
| 393 | """Train the model.""" |
| 394 | |
| 395 | # Turn on training mode which enables dropout. |
| 396 | model.train() |
| 397 | |
| 398 | # Tracking loss. |
| 399 | total_lm_loss = 0.0 |
| 400 | |
| 401 | # Iterations. |
| 402 | iteration = args.iteration |
| 403 | skipped_iters = 0 |
| 404 | |
| 405 | timers('interval time').start() |
| 406 | report_memory_flag = True |
| 407 | while iteration < args.train_iters: |
| 408 | |
| 409 | lm_loss, skipped_iter = train_step(train_data_iterator, |
| 410 | model, |
| 411 | optimizer, |
| 412 | lr_scheduler, |
| 413 | args, timers) |
| 414 | skipped_iters += skipped_iter |
| 415 | iteration += 1 |
| 416 | |
| 417 | # Update losses. |
| 418 | total_lm_loss += lm_loss.data.detach().float() |
| 419 | |
| 420 | # Logging. |
| 421 | if iteration % args.log_interval == 0: |
| 422 | learning_rate = optimizer.param_groups[0]['lr'] |
| 423 | avg_lm_loss = total_lm_loss.item() / args.log_interval |
| 424 | elapsed_time = timers('interval time').elapsed() |
| 425 | log_string = ' iteration {:8d}/{:8d} |'.format(iteration, |
| 426 | args.train_iters) |
| 427 | log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( |
| 428 | elapsed_time * 1000.0 / args.log_interval) |
| 429 | log_string += ' learning rate {:.3E} |'.format(learning_rate) |
| 430 | log_string += ' lm loss {:.6E} |'.format(avg_lm_loss) |
| 431 | if args.fp16: |
| 432 | log_string += ' loss scale {:.1f} |'.format( |
| 433 | optimizer.cur_scale if args.deepspeed else optimizer.loss_scale) |
| 434 | print_rank_0(log_string) |
| 435 | total_lm_loss = 0.0 |
| 436 | if report_memory_flag: |
| 437 | report_memory('after {} iterations'.format(iteration)) |
| 438 | report_memory_flag = False |
| 439 | if USE_TORCH_DDP: |
| 440 | timers.log(['forward', 'backward', 'optimizer', |
| 441 | 'batch generator', 'data loader'], |
| 442 | normalizer=args.log_interval) |
| 443 | else: |
| 444 | timers.log(['forward', 'backward', 'allreduce', 'optimizer', |
| 445 | 'batch generator', 'data loader'], |
| 446 | normalizer=args.log_interval) |
| 447 | # Checkpointing |
| 448 | if args.save and args.save_interval and iteration % args.save_interval == 0: |
no test coverage detected