@torch.no_grad()
@torch.inference_mode()
def encode_vae_inpaint(vae, pixels, mask):
    """Encode an image for inpainting and build a latent-resolution mask.

    Pixels under the (rounded) mask are replaced with the constant 0.5
    before encoding, so the original content of the masked region does not
    reach ``vae.encode``.

    Args:
        vae: Object exposing ``encode(pixels)`` that returns a 4-D tensor.
        pixels: 4-D tensor whose axes 1 and 2 match the last two axes of
            ``mask``. The mask is broadcast over the last axis
            (presumably channels-last image data -- confirm with callers).
        mask: 3-D tensor; values are rounded, so >0.5 means "masked".

    Returns:
        A ``(latent, latent_mask)`` tuple. ``latent`` is the raw result of
        ``vae.encode``. ``latent_mask`` has shape
        ``(mask.shape[0], 1, H, W)`` where ``H, W`` are the last two axes of
        ``latent``, holds only 0/1 values, and shares ``latent``'s dtype
        and device.

    Raises:
        AssertionError: If the ranks or spatial sizes of ``mask`` and
            ``pixels`` do not match. Raised explicitly (rather than via
            ``assert``) so the check survives ``python -O``.
    """
    if mask.ndim != 3 or pixels.ndim != 4:
        raise AssertionError(
            "expected a 3-D mask and 4-D pixels, got "
            f"mask.ndim={mask.ndim} and pixels.ndim={pixels.ndim}"
        )
    if mask.shape[-1] != pixels.shape[-2]:
        raise AssertionError(
            f"mask width {mask.shape[-1]} does not match "
            f"pixels.shape[-2]={pixels.shape[-2]}"
        )
    if mask.shape[-2] != pixels.shape[-3]:
        raise AssertionError(
            f"mask height {mask.shape[-2]} does not match "
            f"pixels.shape[-3]={pixels.shape[-3]}"
        )

    # Binarise the mask and add a trailing axis so it broadcasts over the
    # last axis of ``pixels``; masked pixels become a flat 0.5.
    w = mask.round()[..., None]
    pixels = pixels * (1 - w) + 0.5 * w

    latent = vae.encode(pixels)
    # Four-way unpack on purpose: a non-4-D latent fails loudly right here.
    _, _, latent_h, latent_w = latent.shape

    # Presumably the VAE's spatial downscale factor -- TODO confirm.
    cell = 8

    latent_mask = mask[:, None, :, :]
    # Resize to an exact multiple of the pooling window so that pooling
    # yields precisely one value per latent position.
    latent_mask = torch.nn.functional.interpolate(
        latent_mask, size=(latent_h * cell, latent_w * cell), mode="bilinear"
    ).round()
    # Max-pool: a latent position is masked if ANY pixel in its cell is masked.
    latent_mask = (
        torch.nn.functional.max_pool2d(latent_mask, (cell, cell))
        .round()
        .to(latent)  # match the latent's dtype and device
    )

    return latent, latent_mask
| 194 | |
| 195 | |
| 196 | class VAEApprox(torch.nn.Module): |