MCPcopy
hub / github.com/modelscope/FunASR / validate_epoch

Method validate_epoch

funasr/train_utils/trainer_ds.py:754–846  ·  view source on GitHub ↗

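Defines the validation process for a single epoch: resets the running validation loss and accuracy averages, switches the model to evaluation mode and, with gradient tracking disabled, runs a forward step on every batch of the validation dataloader. Args: model: The model to validate. dataloader_val: The validation dataloader to iterate over. epoch (int): The current epoch number. **kwargs: May carry data_split_i, data_split_num and start_step, which are recorded in the per-batch loss dictionary.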

(
        self,
        model=None,
        dataloader_val=None,
        epoch=None,
        writer=None,
        **kwargs,
    )

Source from the content-addressed store, hash-verified

752 optim.zero_grad(set_to_none=True)
753
754 def validate_epoch(
755 self,
756 model=None,
757 dataloader_val=None,
758 epoch=None,
759 writer=None,
760 **kwargs,
761 ):
762 """
763 Defines the validation process for a single epoch.
764 Should be implemented with the actual model validation steps.
765
766 Args:
767 epoch (int): The current epoch number.
768 """
769 self.val_loss_avg = 0.0
770 self.val_acc_avg = 0.0
771
772 if self.use_ddp or self.use_fsdp or self.use_deepspeed:
773 dist.barrier()
774 logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
775 model.eval()
776
777 with torch.no_grad():
778
779 speed_stats = {}
780 time_beg = time.perf_counter()
781 time5 = time_beg
782
783 dataloader_val.batch_sampler.set_epoch(epoch)
784 for batch_idx, batch in enumerate(dataloader_val):
785
786 loss_dict = {
787 "speed_stats": {},
788 "epoch": epoch,
789 "batch_idx": batch_idx,
790 "data_split_i": kwargs.get("data_split_i", 0),
791 "data_split_num": kwargs.get("data_split_num", 1),
792 "log_step": batch_idx + kwargs.get("start_step", 0),
793 "batch_total": self.batch_total,
794 "step_in_epoch": batch_idx + 1,
795 "lr": 0.0,
796 }
797
798 time1 = time.perf_counter()
799 loss_dict["speed_stats"]["data_load"] = f"{time1 - time_beg:0.3f}"
800
801 batch = to_device(batch, self.device, non_blocking=True)
802
803 time2 = time.perf_counter()
804
805 self.forward_step(model, batch, loss_dict=loss_dict)
806
807 time3 = time.perf_counter()
808 loss_dict["speed_stats"]["forward_time"] = f"{time3 - time2:0.3f}"
809
810 total_time = f"{(time.perf_counter() - time5):0.3f}"
811 time5 = time.perf_counter()

Callers 1

train_epochMethod · 0.95

Calls 6

forward_stepMethod · 0.95
logMethod · 0.95
to_deviceFunction · 0.90
evalMethod · 0.45
set_epochMethod · 0.45
trainMethod · 0.45

Tested by

no test coverage detected