MCPcopy
hub / github.com/yerfor/GeneFacePlusPlus / compute_pr

Function compute_pr

modules/eg3ds/metrics/precision_recall.py:38–63  ·  view source on GitHub ↗
(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size)

Source from the content-addressed store, hash-verified

36#----------------------------------------------------------------------------
37
def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size, detector_url=None):
    """Compute the precision and recall of generated images against real ones.

    Detector features are extracted for the dataset and for the generator via
    ``metric_utils``, cast to float16 and moved to ``opts.device``. For each of
    the two metrics one feature set acts as the "manifold" and the other as
    the "probes":

    * precision: manifold = real features, probes = generated features
    * recall:    manifold = generated features, probes = real features

    Every manifold point gets a radius equal to the ``(nhood_size + 1)``-th
    smallest distance from it to the manifold points. A probe counts as a hit
    when its distance to at least one manifold point is within that point's
    radius; the metric is the fraction of probes that are hits.

    Args:
        opts: Metric options object; this function reads ``opts.device``,
            ``opts.num_gpus`` and ``opts.rank`` and forwards ``opts`` to
            ``metric_utils``.
        max_real: Passed as ``max_items`` when collecting dataset features.
        num_gen: Passed as ``max_items`` when collecting generator features.
        nhood_size: Neighbourhood size ``k``; the ``k + 1``-th smallest
            distance is used (presumably to skip each point's zero distance
            to itself -- confirm against ``compute_distances``).
        row_batch_size: Number of rows (query features) handled per call to
            ``compute_distances``.
        col_batch_size: Forwarded to ``compute_distances`` unchanged.
        detector_url: Optional location of the feature-detector pickle. When
            ``None`` (the default) the original hard-coded local path is
            used, so existing callers are unaffected.

    Returns:
        A ``(precision, recall)`` tuple of Python floats. On ranks other than
        0 both values are ``nan``.
    """
    if detector_url is None:
        # Backward-compatible default: a machine-local copy of the detector.
        # The upstream location previously left commented out here is
        # 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
        # and can now be supplied by the caller without editing this file.
        detector_url = 'file:///home/tiger/nfs/myenv/cache/useful_ckpts/vgg16.pkl'
    detector_kwargs = dict(return_features=True)

    real_features = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)

    gen_features = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)

    results = dict()
    for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
        # Pass 1: per-manifold-point radius (k-th nearest-neighbour distance).
        # compute_distances is called on every rank; only rank 0 uses the
        # result (presumably it participates in cross-GPU communication, so
        # the call must not be skipped on other ranks -- confirm).
        kth = []
        for manifold_batch in manifold.split(row_batch_size):
            dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
            # kthvalue is evaluated in float32, then stored back as float16.
            kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
        kth = torch.cat(kth) if opts.rank == 0 else None

        # Pass 2: a probe is a hit if it lies within the radius of any
        # manifold point (kth broadcasts across the columns of dist).
        pred = []
        for probes_batch in probes.split(row_batch_size):
            dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
            pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
        results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
    return results['precision'], results['recall']
64
65#----------------------------------------------------------------------------

Callers

nothing calls this directly

Calls 5

compute_distancesFunction · 0.85
get_all_torchMethod · 0.80
appendMethod · 0.80
meanMethod · 0.80
toMethod · 0.45

Tested by

no test coverage detected