MCPcopy Index your code
hub / github.com/POSTECH-CVLab/PyTorch-StudioGAN / validate_classifier

Method validate_classifier

src/worker.py:1628–1668  ·  view source on GitHub ↗
(self,model, generator, generator_mapping, generator_synthesis, epoch, GAN_test, setting)

Source from the content-addressed store, hash-verified

# Validate a GAN_train or GAN_test classifier using the generated or training dataset
# -----------------------------------------------------------------------------
def validate_classifier(self, model, generator, generator_mapping, generator_synthesis, epoch, GAN_test, setting):
    """Evaluate a GAN_train / GAN_test classifier and return its averaged metrics.

    ``model`` is switched to eval mode (and left in it) and run over one full pass of
    ``self.train_dataloader``. When ``GAN_test`` is true, every real batch from the loader
    is discarded and replaced by a generated batch from ``sample.generate_images``; the
    loader then only determines how many batches are evaluated. Otherwise the real batch
    is moved to ``self.local_rank`` and evaluated directly.

    Args:
        model: classifier being validated.
        generator: forwarded unchanged to ``sample.generate_images`` (GAN_test only).
        generator_mapping: forwarded unchanged to ``sample.generate_images`` (GAN_test only).
        generator_synthesis: forwarded unchanged to ``sample.generate_images`` (GAN_test only).
        epoch: not used by this method; kept for call-site compatibility.
        GAN_test: if true, validate on generated samples instead of loader samples.
        setting: not used by this method; kept for call-site compatibility.

    Returns:
        Tuple ``(top1_avg, top5_avg, loss_avg)`` read from ``misc.AverageMeter``
        accumulators, each updated once per batch with that batch's size as the count.
    """
    import torch  # local import: this block does not rely on the file-level import list

    model.eval()
    valid_top1_acc, valid_top5_acc, valid_loss = misc.AverageMeter(), misc.AverageMeter(), misc.AverageMeter()
    for images, labels in self.train_dataloader:
        if GAN_test:
            # The real batch is replaced by generated samples; the loader is still iterated
            # so the number of evaluated batches stays the same.
            # NOTE(review): sampling is intentionally kept outside torch.no_grad() because
            # generate_images receives LOSS/RUN configs and may need autograd internally
            # (e.g. latent optimisation) — confirm before moving it inside.
            images, labels, _, _, _, _, _ = sample.generate_images(z_prior=self.MODEL.z_prior,
                                                                   truncation_factor=self.RUN.truncation_factor,
                                                                   batch_size=self.OPTIMIZATION.batch_size,
                                                                   z_dim=self.MODEL.z_dim,
                                                                   num_classes=self.DATA.num_classes,
                                                                   y_sampler="totally_random",
                                                                   radius="N/A",
                                                                   generator=generator,
                                                                   discriminator=self.Dis,
                                                                   is_train=False,
                                                                   LOSS=self.LOSS,
                                                                   RUN=self.RUN,
                                                                   MODEL=self.MODEL,
                                                                   device=self.local_rank,
                                                                   is_stylegan=self.is_stylegan,
                                                                   generator_mapping=generator_mapping,
                                                                   generator_synthesis=generator_synthesis,
                                                                   style_mixing_p=0.0,
                                                                   stylegan_update_emas=False,
                                                                   cal_trsp_cost=False)
        else:
            images, labels = images.to(self.local_rank), labels.to(self.local_rank)

        # Validation only reads the classifier's outputs, so don't build an autograd graph
        # (this also makes the former `output.data` detach unnecessary).
        with torch.no_grad():
            output = model(images)
            ce_loss = self.ce_loss(output, labels)
            valid_acc1, valid_acc5 = misc.accuracy(output, labels, topk=(1, 5))

        num_samples = images.size(0)
        valid_loss.update(ce_loss.item(), num_samples)
        valid_top1_acc.update(valid_acc1.item(), num_samples)
        valid_top5_acc.update(valid_acc5.item(), num_samples)

    if self.local_rank == 0:
        self.logger.info("Top 1-acc {top1.val:.4f} ({top1.avg:.4f})\t"
                         "Top 5-acc {top5.val:.4f} ({top5.avg:.4f})".format(top1=valid_top1_acc, top5=valid_top5_acc))
    return valid_top1_acc.avg, valid_top5_acc.avg, valid_loss.avg

Calls 2

eval · Method · 0.80
update · Method · 0.45

Tested by

no test coverage detected