Train the model.
train(model, optimizer, lr_scheduler,
      train_data_iterator, val_data_iterator, timers, args)
| 299 | |
| 300 | |
| 301 | def train(model, optimizer, lr_scheduler, |
| 302 | train_data_iterator, val_data_iterator, timers, args): |
| 303 | """Train the model.""" |
| 304 | |
| 305 | # Turn on training mode which enables dropout. |
| 306 | model.train() |
| 307 | |
| 308 | # Tracking loss. |
| 309 | total_lm_loss = 0.0 |
| 310 | total_nsp_loss = 0.0 |
| 311 | |
| 312 | # Iterations. |
| 313 | iteration = args.iteration |
| 314 | skipped_iters = 0 |
| 315 | |
| 316 | timers('interval time').start() |
| 317 | report_memory_flag = True |
| 318 | while iteration < args.train_iters: |
| 319 | |
| 320 | lm_loss, nsp_loss, skipped_iter = train_step(train_data_iterator, |
| 321 | model, |
| 322 | optimizer, |
| 323 | lr_scheduler, |
| 324 | args, timers) |
| 325 | skipped_iters += skipped_iter |
| 326 | iteration += 1 |
| 327 | |
| 328 | # Update losses. |
| 329 | total_lm_loss += lm_loss.data.detach().float() |
| 330 | total_nsp_loss += nsp_loss.data.detach().float() |
| 331 | |
| 332 | # Logging. |
| 333 | if iteration % args.log_interval == 0: |
| 334 | learning_rate = optimizer.param_groups[0]['lr'] |
| 335 | avg_nsp_loss = total_nsp_loss.item() / args.log_interval |
| 336 | avg_lm_loss = total_lm_loss.item() / args.log_interval |
| 337 | elapsed_time = timers('interval time').elapsed() |
| 338 | log_string = ' iteration {:8d}/{:8d} |'.format(iteration, |
| 339 | args.train_iters) |
| 340 | log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( |
| 341 | elapsed_time * 1000.0 / args.log_interval) |
| 342 | log_string += ' learning rate {:.3E} |'.format(learning_rate) |
| 343 | log_string += ' lm loss {:.6E} |'.format(avg_lm_loss) |
| 344 | log_string += ' nsp loss {:.6E} |'.format(avg_nsp_loss) |
| 345 | if args.fp16: |
| 346 | log_string += ' loss scale {:.1f} |'.format( |
| 347 | optimizer.loss_scale) |
| 348 | print_rank_0(log_string) |
| 349 | total_nsp_loss = 0.0 |
| 350 | total_lm_loss = 0.0 |
| 351 | if report_memory_flag: |
| 352 | report_memory('after {} iterations'.format(iteration)) |
| 353 | report_memory_flag = False |
| 354 | timers.log(['forward', 'backward', 'optimizer', 'batch generator', |
| 355 | 'data loader'], |
| 356 | normalizer=args.log_interval) |
| 357 | # Checkpointing |
| 358 | if args.save and args.save_interval and iteration % args.save_interval == 0: |
no test coverage detected