Defines the validation process for a single epoch. Should be implemented with the actual model validation steps.

Args:
    epoch (int): The current epoch number.
(
self,
model=None,
dataloader_val=None,
epoch=None,
writer=None,
**kwargs,
)
| 540 | # iterator_stop = torch.tensor(0).to(self.device) |
| 541 | |
| 542 | def validate_epoch( |
| 543 | self, |
| 544 | model=None, |
| 545 | dataloader_val=None, |
| 546 | epoch=None, |
| 547 | writer=None, |
| 548 | **kwargs, |
| 549 | ): |
| 550 | """ |
| 551 | Defines the validation process for a single epoch. |
| 552 | Should be implemented with the actual model validation steps. |
| 553 | |
| 554 | Args: |
| 555 | epoch (int): The current epoch number. |
| 556 | """ |
| 557 | if self.use_ddp or self.use_fsdp: |
| 558 | dist.barrier() |
| 559 | logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n") |
| 560 | model.eval() |
| 561 | |
| 562 | with torch.no_grad(): |
| 563 | |
| 564 | speed_stats = {} |
| 565 | time5 = time.perf_counter() |
| 566 | iterator_stop = torch.tensor(0).to(self.device) |
| 567 | dataloader_val.batch_sampler.set_epoch(epoch) |
| 568 | for batch_idx, batch in enumerate(dataloader_val): |
| 569 | if self.use_ddp or self.use_fsdp: |
| 570 | dist.all_reduce(iterator_stop, dist.ReduceOp.SUM) |
| 571 | if iterator_stop > 0: |
| 572 | break |
| 573 | time1 = time.perf_counter() |
| 574 | speed_stats["data_load"] = f"{time1 - time5:0.3f}" |
| 575 | batch = to_device(batch, self.device, non_blocking=True) |
| 576 | |
| 577 | time2 = time.perf_counter() |
| 578 | retval = model(**batch) |
| 579 | time3 = time.perf_counter() |
| 580 | speed_stats["forward_time"] = f"{time3 - time2:0.3f}" |
| 581 | loss, stats, weight = retval |
| 582 | stats = {k: v for k, v in stats.items() if v is not None} |
| 583 | |
| 584 | if self.use_ddp or self.use_fsdp: |
| 585 | # Apply weighted averaging for loss and stats |
| 586 | loss = (loss * weight.type(loss.dtype)).sum() |
| 587 | # if distributed, this method can also apply all_reduce() |
| 588 | # stats, weight = recursive_average(stats, weight, distributed=True) |
| 589 | if self.use_ddp or self.use_fsdp: |
| 590 | dist.all_reduce(weight, op=dist.ReduceOp.SUM) |
| 591 | # Now weight is summation over all workers |
| 592 | loss /= weight.sum() # shape:[1] -> shape:[] |
| 593 | # Multiply world_size because DistributedDataParallel |
| 594 | # automatically normalizes the gradient by world_size. |
| 595 | loss *= self.world_size |
| 596 | |
| 597 | # Scale the loss since we're not updating for every mini-batch |
| 598 | loss = loss |
| 599 | time4 = time.perf_counter() |