Train the model.
train(model, optimizer, lr_scheduler,
train_data_iterator, val_data_iterator, timers, args, summary_writer=None)
| 480 | |
| 481 | |
| 482 | def train(model, optimizer, lr_scheduler, |
| 483 | train_data_iterator, val_data_iterator, timers, args, summary_writer=None): |
| 484 | """Train the model.""" |
| 485 | # Turn on training mode which enables dropout. |
| 486 | model.train() |
| 487 | |
| 488 | # Tracking loss. |
| 489 | total_lm_loss = 0.0 |
| 490 | total_img_loss = total_txt_loss = 0.0 |
| 491 | |
| 492 | # Iterations. |
| 493 | skipped_iters = 0 |
| 494 | |
| 495 | timers('interval time').start() |
| 496 | report_memory_flag = True |
| 497 | mems = [] |
| 498 | while args.iteration < args.train_iters: |
| 499 | |
| 500 | if args.iteration % 100 == 0: |
| 501 | new_loaders = detect_new_datasets(args) |
| 502 | if new_loaders is not None: |
| 503 | print(f'Loatding new datasets ... Now we train models on {args.train_data}.') |
| 504 | train_data_iterator = iter(new_loaders[0]) |
| 505 | val_data_iterator = iter(new_loaders[1]) |
| 506 | # TODO close the original |
| 507 | |
| 508 | |
| 509 | lm_loss, skipped_iter, mems, img_loss, txt_loss = train_step(train_data_iterator, |
| 510 | model, |
| 511 | optimizer, |
| 512 | lr_scheduler, |
| 513 | args, timers, mems) |
| 514 | skipped_iters += skipped_iter |
| 515 | args.iteration += 1 |
| 516 | |
| 517 | # Update losses. |
| 518 | total_lm_loss += lm_loss.data.detach().float() |
| 519 | total_img_loss += img_loss.data.detach().float() |
| 520 | total_txt_loss += txt_loss.data.detach().float() |
| 521 | |
| 522 | # Logging. |
| 523 | if args.iteration % args.log_interval == 0: |
| 524 | learning_rate = optimizer.param_groups[0]['lr'] |
| 525 | avg_lm_loss = total_lm_loss.item() / args.log_interval |
| 526 | # average img & txt loss |
| 527 | avg_img_loss = total_img_loss.item() / args.log_interval |
| 528 | avg_txt_loss = total_txt_loss.item() / args.log_interval |
| 529 | |
| 530 | elapsed_time = timers('interval time').elapsed() |
| 531 | report_iteration_metrics(summary_writer, optimizer, learning_rate, avg_lm_loss, |
| 532 | elapsed_time * 1000.0 / args.log_interval, args.iteration, args.train_iters, args, |
| 533 | avg_img_loss, avg_txt_loss) |
| 534 | total_lm_loss = 0.0 |
| 535 | total_img_loss = 0.0 |
| 536 | total_txt_loss = 0.0 |
| 537 | if report_memory_flag: |
| 538 | report_memory('after {} iterations'.format(args.iteration)) |
| 539 | report_memory_flag = False |
no test coverage detected