Log training information such as losses, timing, etc.
(
loss_dict,
total_loss_dict,
learning_rate,
iteration,
loss_scale,
report_memory_flag,
skipped_iter,
grad_norm,
params_norm,
num_zeros_in_grad,
model=None,
)
| 598 | |
| 599 | |
| 600 | def training_log( |
| 601 | loss_dict, |
| 602 | total_loss_dict, |
| 603 | learning_rate, |
| 604 | iteration, |
| 605 | loss_scale, |
| 606 | report_memory_flag, |
| 607 | skipped_iter, |
| 608 | grad_norm, |
| 609 | params_norm, |
| 610 | num_zeros_in_grad, |
| 611 | model=None, |
| 612 | ): |
| 613 | """Log training information such as losses, timing, ....""" |
| 614 | args = get_args() |
| 615 | timers = get_timers() |
| 616 | writer = get_tensorboard_writer() |
| 617 | |
| 618 | # Advanced, skipped, and Nan iterations. |
| 619 | advanced_iters_key = "advanced iterations" |
| 620 | skipped_iters_key = "skipped iterations" |
| 621 | nan_iters_key = "nan iterations" |
| 622 | # Advanced iterations. |
| 623 | if not skipped_iter: |
| 624 | total_loss_dict[advanced_iters_key] = ( |
| 625 | total_loss_dict.get(advanced_iters_key, 0) + 1 |
| 626 | ) |
| 627 | else: |
| 628 | if advanced_iters_key not in total_loss_dict: |
| 629 | total_loss_dict[advanced_iters_key] = 0 |
| 630 | # Skipped iterations. |
| 631 | total_loss_dict[skipped_iters_key] = ( |
| 632 | total_loss_dict.get(skipped_iters_key, 0) + skipped_iter |
| 633 | ) |
| 634 | # Update losses and set nan iterations |
| 635 | got_nan = False |
| 636 | for key in loss_dict: |
| 637 | if not skipped_iter: |
| 638 | total_loss_dict[key] = ( |
| 639 | total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + loss_dict[key] |
| 640 | ) |
| 641 | else: |
| 642 | value = loss_dict[key].float().sum().item() |
| 643 | is_nan = value == float("inf") or value == -float("inf") or value != value |
| 644 | got_nan = got_nan or is_nan |
| 645 | total_loss_dict[nan_iters_key] = total_loss_dict.get(nan_iters_key, 0) + int( |
| 646 | got_nan |
| 647 | ) |
| 648 | |
| 649 | # Logging. |
| 650 | timers_to_log = [] |
| 651 | |
| 652 | def add_to_logging(name): |
| 653 | if name in timers.timers: |
| 654 | timers_to_log.append(name) |
| 655 | |
| 656 | add_to_logging("forward-compute") |
| 657 | add_to_logging("forward-recv") |
no test coverage detected