| 18 | from camera_utils import LookAtPoseSampler |
| 19 | |
| 20 | def project( |
| 21 | G, |
| 22 | target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution |
| 23 | c: torch.Tensor, |
| 24 | *, |
| 25 | num_steps = 1000, |
| 26 | w_avg_samples = 10000, |
| 27 | initial_learning_rate = 0.1, |
| 28 | initial_noise_factor = 0.05, |
| 29 | lr_rampdown_length = 0.25, |
| 30 | lr_rampup_length = 0.05, |
| 31 | noise_ramp_length = 0.75, |
| 32 | regularize_noise_weight = 1e5, |
| 33 | optimize_noise = False, |
| 34 | verbose = False, |
| 35 | device: torch.device |
| 36 | ): |
| 37 | assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution) |
| 38 | |
| 39 | def logprint(*args): |
| 40 | if verbose: |
| 41 | print(*args) |
| 42 | |
| 43 | G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore |
| 44 | |
| 45 | # Compute w stats. |
| 46 | logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...') |
| 47 | z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim) |
| 48 | camera_lookat_point = torch.tensor([0, 0, 0.0], device=device) |
| 49 | cam2world_pose = LookAtPoseSampler.sample(3.14/2, 3.14/2, camera_lookat_point, radius=2.7, device=device) |
| 50 | intrinsics = torch.tensor([[4.2647, 0, 0.5], [0, 4.2647, 0.5], [0, 0, 1]], device=device) |
| 51 | c_samples = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) |
| 52 | w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples.repeat(w_avg_samples,1)) # [N, L, C] |
| 53 | w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C] |
| 54 | w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C] |
| 55 | w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5 |
| 56 | |
| 57 | # Setup noise inputs. |
| 58 | noise_bufs = { name: buf for (name, buf) in G.backbone.synthesis.named_buffers() if 'noise_const' in name } |
| 59 | |
| 60 | # Load VGG16 feature detector. |
| 61 | url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' |
| 62 | with dnnlib.util.open_url(url) as f: |
| 63 | vgg16 = torch.jit.load(f).eval().to(device) |
| 64 | |
| 65 | # Features for target image. |
| 66 | target_images = target.unsqueeze(0).to(device).to(torch.float32) / 255.0 * 2 - 1 |
| 67 | target_images_perc = (target_images + 1) * (255/2) |
| 68 | if target_images_perc.shape[2] > 256: |
| 69 | target_images_perc = F.interpolate(target_images_perc, size=(256, 256), mode='area') |
| 70 | target_features = vgg16(target_images_perc, resize_images=False, return_lpips=True) |
| 71 | |
| 72 | w_avg = torch.tensor(w_avg, dtype=torch.float32, device=device).repeat(1, G.backbone.mapping.num_ws, 1) |
| 73 | w_opt = w_avg.detach().clone() |
| 74 | w_opt.requires_grad = True |
| 75 | w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device="cpu") |
| 76 | if optimize_noise: |
| 77 | optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate) |