MCPcopy Index your code
hub / github.com/POSTECH-CVLab/PyTorch-StudioGAN / evaluate

Method evaluate

src/worker.py:805–935  ·  view source on GitHub ↗
(self, step, metrics, writing=True, training=False)

Source from the content-addressed store, hash-verified

803 # evaluate GAN using IS, FID, and Precision and recall.
804 # -----------------------------------------------------------------------------
805 def evaluate(self, step, metrics, writing=True, training=False):
806 if self.global_rank == 0:
807 self.logger.info("Start Evaluation ({step} Step): {run_name}".format(step=step, run_name=self.run_name))
808 if self.gen_ctlr.standing_statistics:
809 self.gen_ctlr.std_stat_counter += 1
810
811 is_best, num_splits, nearest_k = False, 1, 5
812 is_acc = True if "ImageNet" in self.DATA.name and "Tiny" not in self.DATA.name else False
813 requires_grad = self.LOSS.apply_lo or self.RUN.langevin_sampling
814 with torch.no_grad() if not requires_grad else misc.dummy_context_mgr() as ctx:
815 misc.make_GAN_untrainable(self.Gen, self.Gen_ema, self.Dis)
816 generator, generator_mapping, generator_synthesis = self.gen_ctlr.prepare_generator()
817 metric_dict = {}
818
819 fake_feats, fake_probs, fake_labels = features.generate_images_and_stack_features(
820 generator=generator,
821 discriminator=self.Dis,
822 eval_model=self.eval_model,
823 num_generate=self.num_eval[self.RUN.ref_dataset],
824 y_sampler="totally_random",
825 batch_size=self.OPTIMIZATION.batch_size,
826 z_prior=self.MODEL.z_prior,
827 truncation_factor=self.RUN.truncation_factor,
828 z_dim=self.MODEL.z_dim,
829 num_classes=self.DATA.num_classes,
830 LOSS=self.LOSS,
831 RUN=self.RUN,
832 MODEL=self.MODEL,
833 is_stylegan=self.is_stylegan,
834 generator_mapping=generator_mapping,
835 generator_synthesis=generator_synthesis,
836 quantize=True,
837 world_size=self.OPTIMIZATION.world_size,
838 DDP=self.DDP,
839 device=self.local_rank,
840 logger=self.logger,
841 disable_tqdm=self.global_rank != 0)
842
843 if ("fid" in metrics or "prdc" in metrics) and self.global_rank == 0:
844 self.logger.info("{num_images} real images is used for evaluation.".format(num_images=len(self.eval_dataloader.dataset)))
845
846 if "is" in metrics:
847 kl_score, kl_std, top1, top5 = ins.eval_features(probs=fake_probs,
848 labels=fake_labels,
849 data_loader=self.eval_dataloader,
850 num_features=self.num_eval[self.RUN.ref_dataset],
851 split=num_splits,
852 is_acc=is_acc,
853 is_torch_backbone=True if "torch" in self.RUN.eval_backbone else False)
854 if self.global_rank == 0:
855 self.logger.info("Inception score (Step: {step}, {num} generated images): {IS}".format(
856 step=step, num=str(self.num_eval[self.RUN.ref_dataset]), IS=kl_score))
857 if is_acc:
858 self.logger.info("{eval_model} Top1 acc: (Step: {step}, {num} generated images): {Top1}".format(
859 eval_model=self.RUN.eval_backbone, step=step, num=str(self.num_eval[self.RUN.ref_dataset]), Top1=top1))
860 self.logger.info("{eval_model} Top5 acc: (Step: {step}, {num} generated images): {Top5}".format(
861 eval_model=self.RUN.eval_backbone, step=step, num=str(self.num_eval[self.RUN.ref_dataset]), Top5=top5))
862 metric_dict.update({"IS": kl_score, "Top1_acc": top1, "Top5_acc": top5})

Callers 1

load_worker (Function) · 0.95

Calls 2

prepare_generator (Method) · 0.80
update (Method) · 0.45

Tested by

no test coverage detected