MCPcopy
hub / github.com/Vchitect/Latte / compute_feature_stats_for_generator

Function compute_feature_stats_for_generator

tools/metrics/metric_utils.py:263–321  ·  view source on GitHub ↗
(
    opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size: int=16,
    batch_gen=None, jit=False, temporal_detector=False, num_video_frames: int=16,
    feature_stats_cls=FeatureStats, subsample_factor: int=1, **stats_kwargs)

Source from the content-addressed store, hash-verified

261
@torch.no_grad()
def compute_feature_stats_for_generator(
    opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size: int=16,
    batch_gen=None, jit=False, temporal_detector=False, num_video_frames: int=16,
    feature_stats_cls=FeatureStats, subsample_factor: int=1, **stats_kwargs):
    """Accumulate feature-detector statistics over generator samples.

    A frozen copy of ``opts.G`` is sampled in chunks of ``batch_gen`` until
    ``batch_size`` generator calls' worth of output has been collected; the
    concatenated uint8 images are passed through the feature detector and the
    resulting features are appended to a ``feature_stats_cls`` instance.
    This repeats until ``stats.is_full()``.

    Args:
        opts: Options object. This function reads ``G``, ``G_kwargs``,
            ``device``, ``dataset_kwargs``, ``progress``, ``num_gpus`` and
            ``rank`` from it.
        detector_url: Passed to ``get_feature_detector`` as ``url``.
        detector_kwargs: Extra keyword arguments for each detector call.
        rel_lo, rel_hi: Forwarded to ``opts.progress.sub``.
        batch_size: Number of generator samples per detector call.
        batch_gen: Number of samples per generator call; defaults to
            ``min(batch_size, 4)`` and must divide ``batch_size``.
        jit: If true, wrap the generation function with ``torch.jit.trace``.
        temporal_detector: If true, reshape the generator output from
            ``[b * t, c, h, w]`` to ``[b, c, t, h, w]`` before detection.
        num_video_frames: Number of frame indices requested per sample.
        feature_stats_cls: Class used to accumulate features.
        subsample_factor: Stride between consecutive requested frame indices.
        **stats_kwargs: Forwarded to ``feature_stats_cls``; the resulting
            object must have a non-``None`` ``max_items``.

    Returns:
        The populated ``feature_stats_cls`` instance.
    """
    if batch_gen is None:
        batch_gen = min(batch_size, 4)
    assert batch_size % batch_gen == 0, 'batch_size must be a multiple of batch_gen'

    # Setup generator and load labels.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)

    # Image generation func.
    def run_generator(z, c, t):
        img = G(z=z, c=c, t=t, **opts.G_kwargs)
        # Distinct names so the label argument `c` is not shadowed by the channel count.
        bt, ch, h, w = img.shape

        if temporal_detector:
            img = img.view(bt // num_video_frames, num_video_frames, ch, h, w) # [batch_size, t, c, h, w]
            img = img.permute(0, 2, 1, 3, 4).contiguous() # [batch_size, c, t, h, w]

        # Map the generator output to uint8 pixel values in [0, 255].
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        return img

    # JIT.
    if jit:
        # NOTE(review): the trace example for `t` is a float tensor of width
        # G.cfg.sampling.num_frames_per_video, whereas the main loop passes
        # frame indices of width num_video_frames -- confirm these agree.
        example_z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
        example_c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
        example_t = torch.zeros([batch_gen, G.cfg.sampling.num_frames_per_video], device=opts.device)
        run_generator = torch.jit.trace(run_generator, [example_z, example_c, example_t], check_trace=False)

    # Initialize.
    stats = feature_stats_cls(**stats_kwargs)
    assert stats.max_items is not None, 'feature stats must be created with max_items set'
    progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # The requested frame indices are identical for every sample and every
    # iteration, so build the host-side array once instead of per generator call.
    frame_indices = list(range(0, num_video_frames * subsample_factor, subsample_factor))
    frame_indices_np = np.stack([frame_indices] * batch_gen)
    num_gen_calls = batch_size // batch_gen

    # Main loop.
    while not stats.is_full():
        images = []
        for _call_idx in range(num_gen_calls):
            z = torch.randn([batch_gen, G.z_dim], device=opts.device)
            # Condition each sample on the label of a randomly chosen dataset item.
            cond_sample_idx = [np.random.randint(len(dataset)) for _sample_idx in range(batch_gen)]
            c = [dataset.get_label(i) for i in cond_sample_idx]
            c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
            # pin_memory() copies, so each call still receives its own tensor.
            t = torch.from_numpy(frame_indices_np).pin_memory().to(opts.device)
            images.append(run_generator(z, c, t))
        images = torch.cat(images)
        if images.shape[1] == 1:
            # Single-channel output: replicate along dim 1 to get three channels.
            images = images.repeat([1, 3, *([1] * (images.ndim - 2))])
        features = detector(images, **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)
    return stats

Callers

nothing calls this directly

Calls 8

get_feature_detector — Function · 0.85
run_generator — Function · 0.85
sub — Method · 0.80
is_full — Method · 0.80
get_label — Method · 0.80
append — Method · 0.80
append_torch — Method · 0.80
update — Method · 0.45

Tested by

no test coverage detected