| 246 | |
| 247 | |
| 248 | class TransparentVAEDecoder: |
def __init__(self, sd, device, dtype):
    """Construct the transparency decoder network and load its weights.

    Args:
        sd: state dict handed to ``load_state_dict`` with ``strict=True``;
            its keys must match ``UNet1024(in_channels=3, out_channels=4)``.
        device: target device for the network, kept as ``self.load_device``.
        dtype: parameter dtype for the network, kept as ``self.dtype``.
    """
    self.load_device = device
    self.dtype = dtype

    net = UNet1024(in_channels=3, out_channels=4)
    # strict=True: any missing/unexpected key in `sd` is an error rather
    # than being silently ignored.
    net.load_state_dict(sd, strict=True)
    net.to(self.load_device, dtype=self.dtype)
    # The decoder is only used for inference in this class.
    net.eval()
    self.model = net
| 258 | |
@torch.no_grad()
def estimate_single_pass(self, pixel, latent):
    """Run the wrapped model once on ``(pixel, latent)`` and return its output.

    Executed under ``torch.no_grad()``, so no autograd graph is recorded.
    No augmentation, clipping or other post-processing happens here.
    """
    return self.model(pixel, latent)
| 263 | |
@torch.no_grad()
def estimate_augmented(self, pixel, latent):
    """Estimate the output with 8-fold flip/rotation test-time augmentation.

    For every combination of an optional flip along dim 3 and a rotation by
    0, 1, 2 or 3 quarter turns in the (2, 3) plane, both inputs are
    transformed identically and passed through ``estimate_single_pass``.
    Each prediction is clipped to [0, 1], the transform is undone, and the
    element-wise median over the eight de-augmented predictions is returned.

    NOTE(review): dims 2 and 3 are treated as the spatial plane, so inputs
    are presumably (N, C, H, W) tensors — confirm against callers.
    Per the ``torch.median`` docs, with an even sample count (8 here) the
    lower of the two middle values is returned, not their mean.
    """
    transforms = [
        (mirror, quarter_turns)
        for mirror in (False, True)
        for quarter_turns in range(4)
    ]

    predictions = []
    for mirror, quarter_turns in tqdm(transforms):
        x = pixel.clone()
        z = latent.clone()

        # Forward transform: flip first, then rotate.
        if mirror:
            x = torch.flip(x, dims=(3,))
            z = torch.flip(z, dims=(3,))
        x = torch.rot90(x, k=quarter_turns, dims=(2, 3))
        z = torch.rot90(z, k=quarter_turns, dims=(2, 3))

        out = self.estimate_single_pass(x, z).clip(0, 1)

        # Inverse transform in reverse order: un-rotate, then un-flip.
        out = torch.rot90(out, k=-quarter_turns, dims=(2, 3))
        if mirror:
            out = torch.flip(out, dims=(3,))

        predictions.append(out)

    stacked = torch.stack(predictions, dim=0)
    return torch.median(stacked, dim=0).values
| 301 | |
| 302 | @torch.no_grad() |
| 303 | def decode_pixel( |
| 304 | self, pixel: torch.TensorType, latent: torch.TensorType |
| 305 | ) -> torch.TensorType: |