| 48 | self.vgg16 = copy.deepcopy(vgg16) |
| 49 | |
def forward(self, c):
    """Compute differential LPIPS distances for one batch of interpolation pairs.

    For each row of ``c`` two latents (z0, z1) are drawn and interpolated at
    t and t + epsilon: linearly in W (``space == 'w'``) or with ``slerp`` in Z.
    Both endpoints are synthesized, optionally center-cropped, box-downsampled,
    rescaled to [0, 255], and passed through ``self.vgg16`` with
    ``return_lpips=True``. The squared feature difference, summed over dim 1
    and divided by ``epsilon ** 2``, is returned.

    Args:
        c: Conditioning labels. ``c.shape[0]`` is the batch size, and
            ``c.device`` is where the random latents and t-values are drawn.

    Returns:
        ``dist``: one scaled distance per sample, i.e. shape ``[c.shape[0]]``
        (assuming ``self.vgg16`` returns a ``[batch, features]`` tensor).
    """
    batch = c.shape[0]
    device = c.device

    # Random interpolation positions and latent endpoints. Keep torch.rand
    # unconditional: in 'end' mode t is zeroed, but the call still consumes RNG
    # draws, so skipping it would change the latents drawn below.
    t = torch.rand([batch], device=device) * (1 if self.sampling == 'full' else 0)
    z0, z1 = torch.randn([batch * 2, self.G.z_dim], device=device).chunk(2)
    c_pair = torch.cat([c, c])  # labels for the stacked (endpoint 0, endpoint 1) batch

    # Interpolate in W (lerp) or Z (slerp); the two endpoints are epsilon apart in t.
    if self.space == 'w':
        w0, w1 = self.G.mapping(z=torch.cat([z0, z1]), c=c_pair).chunk(2)
        t_w = t.unsqueeze(1).unsqueeze(2)  # broadcast t over the trailing dims of w
        wt0 = w0.lerp(w1, t_w)
        wt1 = w0.lerp(w1, t_w + self.epsilon)
    else:  # space == 'z'
        t_z = t.unsqueeze(1)
        zt0 = slerp(z0, z1, t_z)
        zt1 = slerp(z0, z1, t_z + self.epsilon)
        wt0, wt1 = self.G.mapping(z=torch.cat([zt0, zt1]), c=c_pair).chunk(2)

    # Re-draw every buffer named '*.noise_const' in place before synthesis,
    # which runs with noise_mode='const'.
    for name, buf in self.G.named_buffers():
        if name.endswith('.noise_const'):
            buf.copy_(torch.randn_like(buf))

    # Synthesize both endpoints in a single batch.
    img = self.G.synthesis(ws=torch.cat([wt0, wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)

    # Center crop of a square image: rows [3/8, 7/8), columns [2/8, 6/8).
    # A dedicated name is used so the label tensor `c` is not shadowed.
    if self.crop:
        assert img.shape[2] == img.shape[3]
        eighth = img.shape[2] // 8
        img = img[:, :, eighth * 3 : eighth * 7, eighth * 2 : eighth * 6]

    # Box-downsample by G.img_resolution // 256 (mean over factor x factor tiles).
    # NOTE: the factor is based on the generator resolution, not on the possibly
    # cropped image, so with crop=True the result is smaller than 256x256
    # (e.g. 1024 -> 512 crop -> 128x128). Left unchanged because altering it
    # would change the reported metric values.
    factor = self.G.img_resolution // 256
    if factor > 1:
        img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])

    # Scale dynamic range from [-1, 1] to [0, 255]; replicate grayscale to 3 channels.
    img = (img + 1) * (255 / 2)
    if self.G.img_channels == 1:
        img = img.repeat([1, 3, 1, 1])

    # Differential LPIPS: squared feature distance between the two endpoints,
    # normalized by epsilon ** 2.
    lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
    dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
    return dist
| 93 | |
| 94 | #---------------------------------------------------------------------------- |
| 95 | |