| 184 | |
| 185 | |
| 186 | def project_pti( |
| 187 | G, |
| 188 | target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution |
| 189 | w_pivot: torch.Tensor, |
| 190 | c: torch.Tensor, |
| 191 | *, |
| 192 | num_steps = 1000, |
| 193 | initial_learning_rate = 3e-4, |
| 194 | lr_rampdown_length = 0.25, |
| 195 | lr_rampup_length = 0.05, |
| 196 | verbose = False, |
| 197 | device: torch.device |
| 198 | ): |
| 199 | assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution) |
| 200 | |
| 201 | def logprint(*args): |
| 202 | if verbose: |
| 203 | print(*args) |
| 204 | |
| 205 | G = copy.deepcopy(G).train().requires_grad_(True).to(device) # type: ignore |
| 206 | |
| 207 | # Load VGG16 feature detector. |
| 208 | url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' |
| 209 | with dnnlib.util.open_url(url) as f: |
| 210 | vgg16 = torch.jit.load(f).eval().to(device) |
| 211 | |
| 212 | # Features for target image. |
| 213 | target_images = target.unsqueeze(0).to(device).to(torch.float32) / 255.0 * 2 - 1 |
| 214 | target_images_perc = (target_images + 1) * (255/2) |
| 215 | if target_images_perc.shape[2] > 256: |
| 216 | target_images_perc = F.interpolate(target_images_perc, size=(256, 256), mode='area') |
| 217 | target_features = vgg16(target_images_perc, resize_images=False, return_lpips=True) |
| 218 | |
| 219 | w_pivot = w_pivot.to(device).detach() |
| 220 | optimizer = torch.optim.Adam(G.parameters(), betas=(0.9, 0.999), lr=initial_learning_rate) |
| 221 | |
| 222 | out_params = [] |
| 223 | |
| 224 | for step in range(num_steps): |
| 225 | # Learning rate schedule. |
| 226 | # t = step / num_steps |
| 227 | # lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length) |
| 228 | # lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi) |
| 229 | # lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length) |
| 230 | # lr = initial_learning_rate * lr_ramp |
| 231 | # for param_group in optimizer.param_groups: |
| 232 | # param_group['lr'] = lr |
| 233 | |
| 234 | # Synth images from opt_w. |
| 235 | synth_images = G.synthesis(w_pivot, c=c, noise_mode='const')['image'] |
| 236 | |
| 237 | # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. |
| 238 | synth_images_perc = (synth_images + 1) * (255/2) |
| 239 | if synth_images_perc.shape[2] > 256: |
| 240 | synth_images_perc = F.interpolate(synth_images_perc, size=(256, 256), mode='area') |
| 241 | |
| 242 | # Features for synth images. |
| 243 | synth_features = vgg16(synth_images_perc, resize_images=False, return_lpips=True) |