Repository: github.com/lllyasviel/Fooocus

Function encode_vae_inpaint

modules/core.py, lines 178–193
Signature: encode_vae_inpaint(vae, pixels, mask)


import torch  # torch is imported at the top of modules/core.py


@torch.no_grad()
@torch.inference_mode()
def encode_vae_inpaint(vae, pixels, mask):
    # Expected layouts: mask is (B, H, W), pixels is (B, H, W, C).
    assert mask.ndim == 3 and pixels.ndim == 4
    assert mask.shape[-1] == pixels.shape[-2]
    assert mask.shape[-2] == pixels.shape[-3]

    # Binarize the mask and fill the masked region with 0.5 (mid-gray for pixels in [0, 1]).
    w = mask.round()[..., None]
    pixels = pixels * (1 - w) + 0.5 * w

    latent = vae.encode(pixels)
    B, C, H, W = latent.shape

    # Resize the mask to 8x the latent grid, then max-pool over 8x8 blocks so a latent
    # cell is marked when any (resized) pixel in its block rounds to 1.
    latent_mask = mask[:, None, :, :]
    latent_mask = torch.nn.functional.interpolate(latent_mask, size=(H * 8, W * 8), mode="bilinear").round()
    # .to(latent) matches the latent's dtype and device.
    latent_mask = torch.nn.functional.max_pool2d(latent_mask, (8, 8)).round().to(latent)

    return latent, latent_mask
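To make the two mask steps concrete, here is a dependency-free sketch that mirrors them on nested Python lists for a single grayscale image. The names `gray_fill`, `pool_latent_mask` and `FACTOR` are hypothetical, and the sketch assumes H and W are multiples of 8. In that case the bilinear resize in the original targets the mask's own size and leaves it unchanged, so only the rounding and the 8 × 8 max-pool remain.

```python
# Dependency-free trace of the two masking steps in encode_vae_inpaint for one
# single-channel image stored as nested lists (rows of floats in [0, 1]).
# Assumption: H and W are multiples of 8, so the original's bilinear resize
# maps the mask onto its own size and can be skipped here.

FACTOR = 8  # latent downscale factor, matching the 8 hard-coded in the listing


def gray_fill(pixels, mask):
    """Return pixels * (1 - w) + 0.5 * w, where w is the mask rounded to 0 or 1."""
    out = []
    for prow, mrow in zip(pixels, mask):
        out.append([p * (1.0 - round(m)) + 0.5 * round(m) for p, m in zip(prow, mrow)])
    return out


def pool_latent_mask(mask, factor=FACTOR):
    """Max-pool the rounded mask over factor x factor blocks: 1.0 if any pixel is masked."""
    h, w = len(mask), len(mask[0])
    assert h % factor == 0 and w % factor == 0, "sketch assumes sizes divisible by factor"
    cells = []
    for by in range(0, h, factor):
        row = []
        for bx in range(0, w, factor):
            block = [round(mask[y][x]) for y in range(by, by + factor) for x in range(bx, bx + factor)]
            row.append(float(max(block)))
        cells.append(row)
    return cells


if __name__ == "__main__":
    m = [[0.0] * 16 for _ in range(16)]
    m[3][10] = 1.0  # a single masked pixel inside the top-right 8x8 block
    print(pool_latent_mask(m))  # [[0.0, 1.0], [0.0, 0.0]]
```

Python's `round()` and `torch.round()` both round halves to even, so a mask value of exactly 0.5 counts as unmasked in both.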

Callers

no direct callers detected

Calls (2)

to · Method · 0.80
encode · Method · 0.45

Tested by

no test coverage detected
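No tests are detected, so here is a sketch of what a shape test might look like. `StubVAE` is hypothetical. It stands in for the real VAE by average-pooling (B, H, W, 3) pixels down 8x and appending a fourth channel, following the (B, H, W, C) layout implied by the assertions. That is enough to exercise the mask logic, but it is not Fooocus's VAE. The function is copied locally so the sketch runs without Fooocus installed, keeping only the `@torch.no_grad()` decorator.

```python
import torch
import torch.nn.functional as F


class StubVAE:
    """Hypothetical stand-in for the real VAE (not Fooocus's model).

    Maps (B, H, W, 3) pixels to a (B, 4, H // 8, W // 8) latent by 8x8 average
    pooling plus a fourth channel holding the mean of the other three.
    """

    def encode(self, pixels):
        x = pixels.movedim(-1, 1)  # (B, 3, H, W)
        x = F.avg_pool2d(x, 8)  # (B, 3, H // 8, W // 8); trailing rows/cols are dropped
        return torch.cat([x, x.mean(dim=1, keepdim=True)], dim=1)  # (B, 4, h, w)


@torch.no_grad()
def encode_vae_inpaint(vae, pixels, mask):
    # Local copy of the listing so the sketch runs without Fooocus installed.
    assert mask.ndim == 3 and pixels.ndim == 4
    assert mask.shape[-1] == pixels.shape[-2]
    assert mask.shape[-2] == pixels.shape[-3]
    w = mask.round()[..., None]
    pixels = pixels * (1 - w) + 0.5 * w
    latent = vae.encode(pixels)
    B, C, H, W = latent.shape
    latent_mask = mask[:, None, :, :]
    latent_mask = F.interpolate(latent_mask, size=(H * 8, W * 8), mode="bilinear").round()
    latent_mask = F.max_pool2d(latent_mask, (8, 8)).round().to(latent)
    return latent, latent_mask


def test_shapes_and_single_masked_cell():
    pixels = torch.rand(2, 64, 96, 3)
    mask = torch.zeros(2, 64, 96)
    # Pixel rows 10-11 fall in latent row 1; columns 50-51 fall in latent column 6.
    mask[:, 10:12, 50:52] = 1.0
    latent, latent_mask = encode_vae_inpaint(StubVAE(), pixels, mask)
    assert latent.shape == (2, 4, 8, 12)
    assert latent_mask.shape == (2, 1, 8, 12)
    assert latent_mask.sum().item() == 2  # one marked cell per batch item
    assert latent_mask[0, 0, 1, 6].item() == 1.0


def test_size_not_multiple_of_8():
    pixels = torch.rand(1, 100, 100, 3)
    mask = torch.ones(1, 100, 100)
    latent, latent_mask = encode_vae_inpaint(StubVAE(), pixels, mask)
    # The stub yields a 12 x 12 latent; the mask is resized to 96 x 96 before pooling.
    assert latent_mask.shape[-2:] == latent.shape[-2:] == (12, 12)
    assert latent_mask.min().item() == 1.0


if __name__ == "__main__":
    test_shapes_and_single_masked_cell()
    test_size_not_multiple_of_8()
    print("ok")
```

The second test exercises the bilinear resize: because the stub drops trailing pixels, a 100 × 100 input gives a 12 × 12 latent, and resizing the mask to 96 × 96 before the 8 × 8 max-pool keeps the two grids aligned.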