MCPcopy
hub / github.com/hojonathanho/diffusion / _run_metrics

Method _run_metrics

diffusion_tf/tpu_utils/tpu_utils.py:343–390  ·  view source on GitHub ↗
(self, sess, ema: bool)

Source from the content-addressed store, hash-verified

341 return out
342
def _run_metrics(self, sess, ema: bool):
    """Compute Inception-score and FID metrics for generated samples.

    Args:
        sess: Session handle; it is passed to the cached real-data feature
            getters and used to run the sampling fetches (presumably a
            TensorFlow session -- confirm against the caller).
        ema: When true, fetch `self.ema_samples_inception`; otherwise fetch
            `self.samples_inception`.

    Returns:
        A dict mapping metric names to values. It always contains
        'real_inception_score_train', and contains 'real_inception_score'
        only when `self.inception_real_val` is set. For every key of
        `self.samples_outputs` it adds '<key>/inception<N>' and
        '<key>/trainfid<N>', plus '<key>/fid<N>' when validation features
        are available, where N is `self.num_inception_samples`.
    """
    print('computing sample quality metrics...')
    results = {}

    # Inception activations on the real data; each getter yields a
    # (features, score) pair.
    real_train_feats, real_train_score = self.inception_real_train.get(sess)
    results['real_inception_score_train'] = real_train_score
    real_val_feats = None
    if self.inception_real_val is not None:
        real_val_feats, real_val_score = self.inception_real_val.get(sess)
        results['real_inception_score'] = real_val_score

    # Run enough sampling batches to cover the requested sample count.
    batch_count = int(np.ceil(self.num_inception_samples / self.inception_bs))
    print('generating {} samples and inception features ({} batches)...'.format(
        self.num_inception_samples, batch_count))
    gen_batches = []
    for _ in trange(batch_count, desc='sampling inception batch'):
        fetches = self.ema_samples_inception if ema else self.samples_inception
        gen_batches.append(sess.run(fetches))

    # Every sample output must have matching inception features.
    assert set(self.samples_outputs.keys()) == set(gen_batches[0].keys())
    total = self.num_inception_samples
    for samples_key in self.samples_outputs.keys():
        # Stack the per-batch features, drop any overshoot from the last
        # batch, and promote to float64 before computing statistics.
        gen_feats = {}
        for feat_key in ['pool_3', 'logits']:
            stacked = np.concatenate(
                [batch[samples_key][feat_key] for batch in gen_batches], axis=0)
            gen_feats[feat_key] = stacked[:total].astype(np.float64)
        assert all(v.shape[0] == total for v in gen_feats.values())

        # Inception score of the generated samples.
        inception_score = classifier_metrics_numpy.classifier_score_from_logits(
            gen_feats['logits'])
        results['{}/inception{}'.format(samples_key, total)] = float(inception_score)

        # FID against the training-set activations.
        train_fid = classifier_metrics_numpy.frechet_classifier_distance_from_activations(
            real_train_feats['pool_3'], gen_feats['pool_3'])
        results['{}/trainfid{}'.format(samples_key, total)] = float(train_fid)

        # FID against the validation-set activations, when they exist.
        if real_val_feats is not None:
            val_fid = classifier_metrics_numpy.frechet_classifier_distance_from_activations(
                real_val_feats['pool_3'], gen_feats['pool_3'])
            results['{}/fid{}'.format(samples_key, total)] = float(val_fid)

    return results
391
392 def _write_eval_and_samples(self, sess, log: utils.SummaryWriter, curr_step, prefix, ema: bool):
393 # Samples

Callers 1

Calls 2

getMethod · 0.80
runMethod · 0.45

Tested by

no test coverage detected