(self, embed_dim, ch_mult, use_variational=True, ckpt_path=None)
| 449 | |
| 450 | class AutoencoderKL(nn.Module): |
def __init__(self, embed_dim, ch_mult, use_variational=True, ckpt_path=None):
    """Build the encoder/decoder pair and the 1x1 latent projection convs.

    Args:
        embed_dim: Latent channel count. It is forwarded to both ``Encoder``
            and ``Decoder`` as ``z_channels`` and stored as ``self.embed_dim``.
        ch_mult: Channel multipliers, forwarded unchanged to ``Encoder`` and
            ``Decoder``.
        use_variational: Stored as ``self.use_variational``. When truthy,
            ``quant_conv`` emits ``2 * embed_dim`` channels; otherwise it
            emits ``embed_dim`` channels.
        ckpt_path: Optional checkpoint path. When not None,
            ``self.init_from_ckpt(ckpt_path)`` is called as the final step
            of construction.
    """
    super().__init__()

    # Keep this construction order. Submodule registration order determines
    # parameter/state_dict ordering, and (presumably) any random weight
    # initialisation inside these modules draws from the global RNG in
    # this order.
    self.encoder = Encoder(ch_mult=ch_mult, z_channels=embed_dim)
    self.decoder = Decoder(ch_mult=ch_mult, z_channels=embed_dim)

    self.use_variational = use_variational
    channel_factor = 1
    if self.use_variational:
        # Double the output width. Presumably the two halves are the mean
        # and log-variance of the latent distribution; confirm against the
        # encode path, which is not shown here.
        channel_factor = 2

    # NOTE(review): the input width assumes the encoder emits
    # 2 * embed_dim channels; confirm against Encoder's output.
    self.quant_conv = torch.nn.Conv2d(
        in_channels=2 * embed_dim,
        out_channels=channel_factor * embed_dim,
        kernel_size=1,
    )
    self.post_quant_conv = torch.nn.Conv2d(
        in_channels=embed_dim,
        out_channels=embed_dim,
        kernel_size=1,
    )
    self.embed_dim = embed_dim

    if ckpt_path is not None:
        self.init_from_ckpt(ckpt_path)
| 462 | |
| 463 | def init_from_ckpt(self, path): |
| 464 | sd = torch.load(path, map_location="cpu")["model"] |
nothing calls this directly
no test coverage detected