| 16 | #---------------------------------------------------------------------------- |
| 17 | |
| 18 | def compute_is(opts, num_gen, num_splits): |
| 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz |
| 20 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' |
| 21 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. |
| 22 | |
| 23 | if opts.generator_as_dataset: |
| 24 | compute_gen_stats_fn = metric_utils.compute_feature_stats_for_dataset |
| 25 | gen_opts = metric_utils.rewrite_opts_for_gen_dataset(opts) |
| 26 | gen_kwargs = dict(use_image_dataset=True) |
| 27 | else: |
| 28 | compute_gen_stats_fn = metric_utils.compute_feature_stats_for_generator |
| 29 | gen_opts = opts |
| 30 | gen_kwargs = dict() |
| 31 | |
| 32 | gen_probs = compute_gen_stats_fn( |
| 33 | opts=gen_opts, detector_url=detector_url, detector_kwargs=detector_kwargs, |
| 34 | capture_all=True, max_items=num_gen, **gen_kwargs).get_all() |
| 35 | |
| 36 | if opts.rank != 0: |
| 37 | return float('nan'), float('nan') |
| 38 | |
| 39 | scores = [] |
| 40 | for i in range(num_splits): |
| 41 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] |
| 42 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) |
| 43 | kl = np.mean(np.sum(kl, axis=1)) |
| 44 | print(kl) |
| 45 | scores.append(np.exp(kl)) |
| 46 | return float(np.mean(scores)), float(np.std(scores)) |
| 47 | |
| 48 | #---------------------------------------------------------------------------- |