(self,model, generator, generator_mapping, generator_synthesis, epoch, GAN_test, setting)
| 1626 | # validate GAN_train or GAN_test classifier using generated or training dataset |
| 1627 | # ----------------------------------------------------------------------------- |
def validate_classifier(self, model, generator, generator_mapping, generator_synthesis, epoch, GAN_test, setting):
    """Evaluate a GAN-train / GAN-test classifier and return its average metrics.

    The loop runs once per batch of ``self.train_dataloader``:

    * If ``GAN_test`` is true, the real batch is discarded. It is replaced by a
      freshly generated batch from ``sample.generate_images``, with
      ``self.OPTIMIZATION.batch_size`` samples and ``y_sampler="totally_random"``.
      In this case the dataloader only fixes how many batches are evaluated.
    * Otherwise the real batch is moved to ``self.local_rank`` and evaluated directly.

    For each batch, the cross-entropy (``self.ce_loss``) and the top-1 / top-5
    accuracy (``misc.accuracy``) are accumulated, weighted by the batch size.

    Args:
        model: Classifier to evaluate. It is switched to eval mode and is *not*
            switched back, so the caller must restore train mode if needed.
        generator, generator_mapping, generator_synthesis: Generator modules
            forwarded to ``sample.generate_images``. They are only used when
            ``GAN_test`` is true.
        epoch: Unused. Kept for interface compatibility.
        GAN_test: If true, evaluate on generated images. Otherwise evaluate on
            batches from ``self.train_dataloader``.
        setting: Unused. Kept for interface compatibility.

    Returns:
        Tuple ``(top1_avg, top5_avg, loss_avg)`` of running averages over all
        evaluated batches.
    """
    import torch  # local import keeps this block self-contained

    model.eval()
    valid_top1_acc, valid_top5_acc, valid_loss = misc.AverageMeter(), misc.AverageMeter(), misc.AverageMeter()

    # Evaluation only needs forward passes. Disabling autograd avoids building
    # and retaining a computation graph for every batch.
    with torch.no_grad():
        for images, labels in self.train_dataloader:
            if GAN_test:
                # generate_images returns a 7-tuple; only images and labels are needed here.
                images, labels, _, _, _, _, _ = sample.generate_images(z_prior=self.MODEL.z_prior,
                                                                       truncation_factor=self.RUN.truncation_factor,
                                                                       batch_size=self.OPTIMIZATION.batch_size,
                                                                       z_dim=self.MODEL.z_dim,
                                                                       num_classes=self.DATA.num_classes,
                                                                       y_sampler="totally_random",
                                                                       radius="N/A",
                                                                       generator=generator,
                                                                       discriminator=self.Dis,
                                                                       is_train=False,
                                                                       LOSS=self.LOSS,
                                                                       RUN=self.RUN,
                                                                       MODEL=self.MODEL,
                                                                       device=self.local_rank,
                                                                       is_stylegan=self.is_stylegan,
                                                                       generator_mapping=generator_mapping,
                                                                       generator_synthesis=generator_synthesis,
                                                                       style_mixing_p=0.0,
                                                                       stylegan_update_emas=False,
                                                                       cal_trsp_cost=False)
            else:
                images, labels = images.to(self.local_rank), labels.to(self.local_rank)

            output = model(images)
            ce_loss = self.ce_loss(output, labels)
            valid_acc1, valid_acc5 = misc.accuracy(output, labels, topk=(1, 5))

            # Weight each batch's metrics by its actual size; the last real batch may be smaller.
            num_samples = images.size(0)
            valid_loss.update(ce_loss.item(), num_samples)
            valid_top1_acc.update(valid_acc1.item(), num_samples)
            valid_top5_acc.update(valid_acc5.item(), num_samples)

    if self.local_rank == 0:
        self.logger.info("Top 1-acc {top1.val:.4f} ({top1.avg:.4f})\t"
                         "Top 5-acc {top5.val:.4f} ({top5.avg:.4f})".format(top1=valid_top1_acc, top5=valid_top5_acc))
    return valid_top1_acc.avg, valid_top5_acc.avg, valid_loss.avg
no test coverage detected