(self, sess, ema: bool)
| 341 | return out |
| 342 | |
def _run_metrics(self, sess, ema: bool):
    """Compute Inception score and FID for freshly generated samples.

    Draws ``self.num_inception_samples`` samples, in batches of
    ``self.inception_bs``, from either the EMA or the regular sampling op.
    It concatenates their Inception ``pool_3`` features and ``logits``, then
    scores every entry of ``self.samples_outputs`` against the cached
    real-data activations.

    Args:
        sess: session used to evaluate the sampling op via ``sess.run``.
        ema: if True, sample via ``self.ema_samples_inception``; otherwise
            via ``self.samples_inception``.

    Returns:
        dict mapping metric name to value:
          * ``real_inception_score_train`` and, when a validation set is
            configured, ``real_inception_score``. These are the second
            element returned by the respective real-data ``.get(sess)``.
          * ``<samples_key>/inception<N>``: Inception score of the samples.
          * ``<samples_key>/trainfid<N>``: FID against the training set.
          * ``<samples_key>/fid<N>``: FID against the validation set (only
            when ``self.inception_real_val`` is not None).
        Here N is ``self.num_inception_samples``.

    Raises:
        ValueError: if ``self.num_inception_samples`` is not positive. In
            that case no batches would be generated and there would be
            nothing to score.
    """
    num_samples = self.num_inception_samples
    # Fail fast, before any expensive work. With zero batches the
    # `inception_gen_batches[0]` lookup below would otherwise raise an
    # opaque IndexError.
    if num_samples <= 0:
        raise ValueError(
            'num_inception_samples must be positive, got {}'.format(num_samples))

    print('computing sample quality metrics...')
    metrics = {}

    # Get Inception activations on the real dataset. Each .get(sess) call
    # unpacks into (activations, score). The activations are indexed by
    # 'pool_3' for the FID computations below.
    cached_inception_real_train, metrics['real_inception_score_train'] = self.inception_real_train.get(sess)
    if self.inception_real_val is not None:
        cached_inception_real_val, metrics['real_inception_score'] = self.inception_real_val.get(sess)
    else:
        cached_inception_real_val = None

    # Generate lots of samples. The batch count is rounded up, so the last
    # batch may overshoot. The surplus is trimmed after concatenation.
    num_inception_gen_batches = int(np.ceil(num_samples / self.inception_bs))
    print('generating {} samples and inception features ({} batches)...'.format(
        num_samples, num_inception_gen_batches))
    # The op choice does not change between batches, so select it once.
    samples_op = self.ema_samples_inception if ema else self.samples_inception
    inception_gen_batches = [
        sess.run(samples_op)
        for _ in trange(num_inception_gen_batches, desc='sampling inception batch')
    ]

    # Compute FID and Inception score.
    # Each sess.run result is expected to be keyed exactly like samples_outputs.
    assert set(self.samples_outputs.keys()) == set(inception_gen_batches[0].keys())
    for samples_key in self.samples_outputs.keys():
        # Concat features from all batches into a single array, trimmed to
        # exactly num_samples rows and promoted to float64.
        inception_gen = {
            feat_key: np.concatenate(
                [batch[samples_key][feat_key] for batch in inception_gen_batches], axis=0
            )[:num_samples].astype(np.float64)
            for feat_key in ['pool_3', 'logits']
        }
        assert all(v.shape[0] == num_samples for v in inception_gen.values())

        # Inception score
        metrics['{}/inception{}'.format(samples_key, num_samples)] = float(
            classifier_metrics_numpy.classifier_score_from_logits(inception_gen['logits']))

        # FID vs training set
        metrics['{}/trainfid{}'.format(samples_key, num_samples)] = float(
            classifier_metrics_numpy.frechet_classifier_distance_from_activations(
                cached_inception_real_train['pool_3'], inception_gen['pool_3']))

        # FID vs val set (only when a validation set is configured)
        if cached_inception_real_val is not None:
            metrics['{}/fid{}'.format(samples_key, num_samples)] = float(
                classifier_metrics_numpy.frechet_classifier_distance_from_activations(
                    cached_inception_real_val['pool_3'], inception_gen['pool_3']))

    return metrics
| 391 | |
| 392 | def _write_eval_and_samples(self, sess, log: utils.SummaryWriter, curr_step, prefix, ema: bool): |
| 393 | # Samples |
no test coverage detected