MCPcopy
hub / github.com/modelscope/FunASR / validate_epoch

Method validate_epoch

funasr/train_utils/trainer.py:542–654  ·  view source on GitHub ↗

Defines the validation process for a single epoch. Should be implemented with the actual model validation steps. Args: epoch (int): The current epoch number.

(
        self,
        model=None,
        dataloader_val=None,
        epoch=None,
        writer=None,
        **kwargs,
    )

Source from the content-addressed store, hash-verified

540 # iterator_stop = torch.tensor(0).to(self.device)
541
542 def validate_epoch(
543 self,
544 model=None,
545 dataloader_val=None,
546 epoch=None,
547 writer=None,
548 **kwargs,
549 ):
550 """
551 Defines the validation process for a single epoch.
552 Should be implemented with the actual model validation steps.
553
554 Args:
555 epoch (int): The current epoch number.
556 """
557 if self.use_ddp or self.use_fsdp:
558 dist.barrier()
559 logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
560 model.eval()
561
562 with torch.no_grad():
563
564 speed_stats = {}
565 time5 = time.perf_counter()
566 iterator_stop = torch.tensor(0).to(self.device)
567 dataloader_val.batch_sampler.set_epoch(epoch)
568 for batch_idx, batch in enumerate(dataloader_val):
569 if self.use_ddp or self.use_fsdp:
570 dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
571 if iterator_stop > 0:
572 break
573 time1 = time.perf_counter()
574 speed_stats["data_load"] = f"{time1 - time5:0.3f}"
575 batch = to_device(batch, self.device, non_blocking=True)
576
577 time2 = time.perf_counter()
578 retval = model(**batch)
579 time3 = time.perf_counter()
580 speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
581 loss, stats, weight = retval
582 stats = {k: v for k, v in stats.items() if v is not None}
583
584 if self.use_ddp or self.use_fsdp:
585 # Apply weighted averaging for loss and stats
586 loss = (loss * weight.type(loss.dtype)).sum()
587 # if distributed, this method can also apply all_reduce()
588 # stats, weight = recursive_average(stats, weight, distributed=True)
589 if self.use_ddp or self.use_fsdp:
590 dist.all_reduce(weight, op=dist.ReduceOp.SUM)
591 # Now weight is summation over all workers
592 loss /= weight.sum() # shape:[1] -> shape:[]
593 # Multiply world_size because DistributedDataParallel
594 # automatically normalizes the gradient by world_size.
595 loss *= self.world_size
596
597 # Scale the loss since we're not updating for every mini-batch
598 loss = loss
599 time4 = time.perf_counter()

Callers 3

train_epoch (method) · 0.95
main (function) · 0.95
main (function) · 0.95

Calls 5

log (method) · 0.95
to_device (function) · 0.90
eval (method) · 0.45
set_epoch (method) · 0.45
train (method) · 0.45

Tested by

no test coverage detected