(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size)
| 36 | #---------------------------------------------------------------------------- |
| 37 | |
def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size,
               detector_url='file:///home/tiger/nfs/myenv/cache/useful_ckpts/vgg16.pkl'):
    """Compute k-NN-radius precision and recall between real and generated features.

    Features are extracted for real items (dataset) and generated items (generator)
    with the detector loaded from ``detector_url``. The two feature sets take turns
    as the "manifold":

    - Each manifold sample gets a radius. The radius is its distance to its
      ``nhood_size``-th nearest neighbour inside the manifold.
    - A probe from the other set counts as covered when its distance to at least
      one manifold sample is <= that sample's radius.

    * precision: fraction of generated features covered by the real manifold.
    * recall:    fraction of real features covered by the generated manifold.

    Args:
        opts: Metric options object. This function reads ``opts.device``,
            ``opts.num_gpus`` and ``opts.rank``, and forwards ``opts`` to
            ``metric_utils``.
        max_real: Passed as ``max_items`` when extracting real-dataset features.
        num_gen: Passed as ``max_items`` when extracting generator features.
        nhood_size: Neighbourhood size k used to define each manifold sample's radius.
        row_batch_size: Number of feature rows per call to ``compute_distances``.
        col_batch_size: Forwarded unchanged to ``compute_distances``.
        detector_url: URL of the feature-detector pickle. The default is the
            previously hard-coded local path, so existing callers are unaffected.
            Pass the public NGC URL (see comment in the body) on machines that
            do not have that file.

    Returns:
        ``(precision, recall)`` as Python floats on rank 0. On every other rank
        both values are ``float('nan')``.
    """
    # Public location of the same detector, usable when the local copy is absent:
    # 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
    detector_kwargs = dict(return_features=True)

    # Both feature sets are held in float16 on opts.device
    # (presumably to bound memory for large sample counts).
    real_features = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)

    gen_features = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)

    results = {}
    for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
        # Radius per manifold sample.
        # Each batch is compared against the whole manifold, itself included.
        # Taking the (nhood_size + 1)-th smallest distance along each row therefore
        # skips the self-match (presumably distance 0).
        # The float32 round-trip before kthvalue is presumably needed because
        # half-precision kthvalue is unsupported on some backends.
        # compute_distances appears to yield usable results only on rank 0;
        # every other rank collects None.
        kth = []
        for manifold_batch in manifold.split(row_batch_size):
            dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
            kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
        kth = torch.cat(kth) if opts.rank == 0 else None

        # Coverage test.
        # dist has one row per probe and one column per manifold sample.
        # kth broadcasts across rows, so column j is compared with manifold
        # sample j's radius.
        pred = []
        for probes_batch in probes.split(row_batch_size):
            dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
            pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)

        # Non-zero ranks report NaN, since they hold no distances.
        results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
    return results['precision'], results['recall']
| 64 | |
| 65 | #---------------------------------------------------------------------------- |
nothing calls this directly
no test coverage detected