(
opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size: int=16,
batch_gen=None, jit=False, temporal_detector=False, num_video_frames: int=16,
feature_stats_cls=FeatureStats, subsample_factor: int=1, **stats_kwargs)
| 261 | |
@torch.no_grad()
def compute_feature_stats_for_generator(
    opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size: int=16,
    batch_gen=None, jit=False, temporal_detector=False, num_video_frames: int=16,
    feature_stats_cls=FeatureStats, subsample_factor: int=1, **stats_kwargs):
    """Sample from ``opts.G`` and accumulate detector features until the stats object is full.

    Each outer iteration produces ``batch_size`` samples, generated in chunks of
    ``batch_gen``:

    - ``z`` is drawn from ``torch.randn``.
    - ``c`` holds labels of uniformly random dataset items.
    - ``t`` holds the same frame indices for every sample:
      ``0, subsample_factor, 2 * subsample_factor, ...`` (``num_video_frames`` values).

    Generator output is scaled by 127.5, shifted by 128, clamped to [0, 255] and
    cast to uint8. This presumably assumes G outputs values roughly in [-1, 1];
    confirm against the generator.

    Args:
        opts: Options object. Reads ``G``, ``G_kwargs``, ``device``, ``dataset_kwargs``,
            ``progress``, ``num_gpus`` and ``rank``.
        detector_url: Passed to ``get_feature_detector`` to obtain the feature network.
        detector_kwargs: Extra keyword arguments forwarded to every detector call.
        rel_lo, rel_hi: Sub-range of the parent progress bar used by this stage.
        batch_size: Number of samples passed to the detector per call.
            Must be divisible by ``batch_gen``.
        batch_gen: Number of samples per generator call. Defaults to ``min(batch_size, 4)``.
        jit: If True, trace the generator wrapper with ``torch.jit.trace`` before sampling.
        temporal_detector: If True, the generator's flattened frames are regrouped
            into clips of shape ``[batch, channels, num_video_frames, h, w]``.
        num_video_frames: Number of frame indices requested per sample.
        feature_stats_cls: Stats class instantiated with ``**stats_kwargs``.
            Its ``max_items`` must not be None.
        subsample_factor: Stride between consecutive requested frame indices.
        **stats_kwargs: Keyword arguments for ``feature_stats_cls``.

    Returns:
        The filled ``feature_stats_cls`` instance.

    Raises:
        AssertionError: If ``batch_size`` is not a multiple of ``batch_gen``,
            or if ``stats.max_items`` is None.
    """
    if batch_gen is None:
        batch_gen = min(batch_size, 4)
    assert batch_size % batch_gen == 0

    # Setup generator and load labels (the dataset is only sampled for conditioning labels here).
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)

    # Frame indices are identical for every sample and every iteration, so build them once.
    # The same tensor is reused for JIT tracing, so the trace sees the real shape and dtype.
    # NOTE(review): assumes G does not modify `t` in place; confirm if G changes.
    frame_ids = list(range(0, num_video_frames * subsample_factor, subsample_factor))
    t = torch.from_numpy(np.stack([frame_ids] * batch_gen)).pin_memory().to(opts.device)

    # Image generation func.
    def run_generator(z, c, t):
        img = G(z=z, c=c, t=t, **opts.G_kwargs)
        bt, num_channels, h, w = img.shape  # don't shadow the label argument `c`

        if temporal_detector:
            # [batch_size * t, c, h, w] -> [batch_size, t, c, h, w] -> [batch_size, c, t, h, w]
            img = img.view(bt // num_video_frames, num_video_frames, num_channels, h, w)
            img = img.permute(0, 2, 1, 3, 4).contiguous()

        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        return img

    # JIT: trace with inputs that match what the sampling loop actually passes.
    if jit:
        z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
        c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
        run_generator = torch.jit.trace(run_generator, [z, c, t], check_trace=False)

    # Initialize.
    stats = feature_stats_cls(**stats_kwargs)
    assert stats.max_items is not None
    progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # Main loop.
    while not stats.is_full():
        images = []
        for _ in range(batch_size // batch_gen):
            z = torch.randn([batch_gen, G.z_dim], device=opts.device)
            cond_sample_idx = [np.random.randint(len(dataset)) for _ in range(batch_gen)]
            c = np.stack([dataset.get_label(i) for i in cond_sample_idx])
            c = torch.from_numpy(c).pin_memory().to(opts.device)
            images.append(run_generator(z, c, t))
        images = torch.cat(images)
        # Grayscale -> RGB: repeat the channel axis (dim 1) three times.
        if images.shape[1] == 1:
            images = images.repeat([1, 3, *([1] * (images.ndim - 2))])
        features = detector(images, **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)
    return stats
nothing calls this directly
no test coverage detected