(self, step, metrics, writing=True, training=False)
| 803 | # Evaluate the GAN using IS, FID, and precision and recall. |
| 804 | # ----------------------------------------------------------------------------- |
| 805 | def evaluate(self, step, metrics, writing=True, training=False): |
| 806 | if self.global_rank == 0: |
| 807 | self.logger.info("Start Evaluation ({step} Step): {run_name}".format(step=step, run_name=self.run_name)) |
| 808 | if self.gen_ctlr.standing_statistics: |
| 809 | self.gen_ctlr.std_stat_counter += 1 |
| 810 | |
| 811 | is_best, num_splits, nearest_k = False, 1, 5 |
| 812 | is_acc = True if "ImageNet" in self.DATA.name and "Tiny" not in self.DATA.name else False |
| 813 | requires_grad = self.LOSS.apply_lo or self.RUN.langevin_sampling |
| 814 | with torch.no_grad() if not requires_grad else misc.dummy_context_mgr() as ctx: |
| 815 | misc.make_GAN_untrainable(self.Gen, self.Gen_ema, self.Dis) |
| 816 | generator, generator_mapping, generator_synthesis = self.gen_ctlr.prepare_generator() |
| 817 | metric_dict = {} |
| 818 | |
| 819 | fake_feats, fake_probs, fake_labels = features.generate_images_and_stack_features( |
| 820 | generator=generator, |
| 821 | discriminator=self.Dis, |
| 822 | eval_model=self.eval_model, |
| 823 | num_generate=self.num_eval[self.RUN.ref_dataset], |
| 824 | y_sampler="totally_random", |
| 825 | batch_size=self.OPTIMIZATION.batch_size, |
| 826 | z_prior=self.MODEL.z_prior, |
| 827 | truncation_factor=self.RUN.truncation_factor, |
| 828 | z_dim=self.MODEL.z_dim, |
| 829 | num_classes=self.DATA.num_classes, |
| 830 | LOSS=self.LOSS, |
| 831 | RUN=self.RUN, |
| 832 | MODEL=self.MODEL, |
| 833 | is_stylegan=self.is_stylegan, |
| 834 | generator_mapping=generator_mapping, |
| 835 | generator_synthesis=generator_synthesis, |
| 836 | quantize=True, |
| 837 | world_size=self.OPTIMIZATION.world_size, |
| 838 | DDP=self.DDP, |
| 839 | device=self.local_rank, |
| 840 | logger=self.logger, |
| 841 | disable_tqdm=self.global_rank != 0) |
| 842 | |
| 843 | if ("fid" in metrics or "prdc" in metrics) and self.global_rank == 0: |
| 844 | self.logger.info("{num_images} real images is used for evaluation.".format(num_images=len(self.eval_dataloader.dataset))) |
| 845 | |
| 846 | if "is" in metrics: |
| 847 | kl_score, kl_std, top1, top5 = ins.eval_features(probs=fake_probs, |
| 848 | labels=fake_labels, |
| 849 | data_loader=self.eval_dataloader, |
| 850 | num_features=self.num_eval[self.RUN.ref_dataset], |
| 851 | split=num_splits, |
| 852 | is_acc=is_acc, |
| 853 | is_torch_backbone=True if "torch" in self.RUN.eval_backbone else False) |
| 854 | if self.global_rank == 0: |
| 855 | self.logger.info("Inception score (Step: {step}, {num} generated images): {IS}".format( |
| 856 | step=step, num=str(self.num_eval[self.RUN.ref_dataset]), IS=kl_score)) |
| 857 | if is_acc: |
| 858 | self.logger.info("{eval_model} Top1 acc: (Step: {step}, {num} generated images): {Top1}".format( |
| 859 | eval_model=self.RUN.eval_backbone, step=step, num=str(self.num_eval[self.RUN.ref_dataset]), Top1=top1)) |
| 860 | self.logger.info("{eval_model} Top5 acc: (Step: {step}, {num} generated images): {Top5}".format( |
| 861 | eval_model=self.RUN.eval_backbone, step=step, num=str(self.num_eval[self.RUN.ref_dataset]), Top5=top5)) |
| 862 | metric_dict.update({"IS": kl_score, "Top1_acc": top1, "Top5_acc": top5}) |
no test coverage detected