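Defines the validation process for a single epoch: puts the model in eval mode and, with gradient tracking disabled, runs `forward_step` on every batch from `dataloader_val`. Args: model: The model to validate. dataloader_val: The validation dataloader; `batch_sampler.set_epoch(epoch)` is called on it before iteration. epoch (int): The current epoch number. writer: Not used in the portion of the method shown here. **kwargs: Optional `data_split_i`, `data_split_num`, and `start_step` values, read into the per-batch `loss_dict`.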
(
self,
model=None,
dataloader_val=None,
epoch=None,
writer=None,
**kwargs,
)
| 752 | optim.zero_grad(set_to_none=True) |
| 753 | |
| 754 | def validate_epoch( |
| 755 | self, |
| 756 | model=None, |
| 757 | dataloader_val=None, |
| 758 | epoch=None, |
| 759 | writer=None, |
| 760 | **kwargs, |
| 761 | ): |
| 762 | """ |
| 763 | Defines the validation process for a single epoch. |
| 764 | Should be implemented with the actual model validation steps. |
| 765 | |
| 766 | Args: |
| 767 | epoch (int): The current epoch number. |
| 768 | """ |
| 769 | self.val_loss_avg = 0.0 |
| 770 | self.val_acc_avg = 0.0 |
| 771 | |
| 772 | if self.use_ddp or self.use_fsdp or self.use_deepspeed: |
| 773 | dist.barrier() |
| 774 | logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n") |
| 775 | model.eval() |
| 776 | |
| 777 | with torch.no_grad(): |
| 778 | |
| 779 | speed_stats = {} |
| 780 | time_beg = time.perf_counter() |
| 781 | time5 = time_beg |
| 782 | |
| 783 | dataloader_val.batch_sampler.set_epoch(epoch) |
| 784 | for batch_idx, batch in enumerate(dataloader_val): |
| 785 | |
| 786 | loss_dict = { |
| 787 | "speed_stats": {}, |
| 788 | "epoch": epoch, |
| 789 | "batch_idx": batch_idx, |
| 790 | "data_split_i": kwargs.get("data_split_i", 0), |
| 791 | "data_split_num": kwargs.get("data_split_num", 1), |
| 792 | "log_step": batch_idx + kwargs.get("start_step", 0), |
| 793 | "batch_total": self.batch_total, |
| 794 | "step_in_epoch": batch_idx + 1, |
| 795 | "lr": 0.0, |
| 796 | } |
| 797 | |
| 798 | time1 = time.perf_counter() |
| 799 | loss_dict["speed_stats"]["data_load"] = f"{time1 - time_beg:0.3f}" |
| 800 | |
| 801 | batch = to_device(batch, self.device, non_blocking=True) |
| 802 | |
| 803 | time2 = time.perf_counter() |
| 804 | |
| 805 | self.forward_step(model, batch, loss_dict=loss_dict) |
| 806 | |
| 807 | time3 = time.perf_counter() |
| 808 | loss_dict["speed_stats"]["forward_time"] = f"{time3 - time2:0.3f}" |
| 809 | |
| 810 | total_time = f"{(time.perf_counter() - time5):0.3f}" |
| 811 | time5 = time.perf_counter() |
no test coverage detected