| 273 | |
| 274 | |
| 275 | class Decoder(nn.Module): |
| 276 | def __init__( |
| 277 | self, |
| 278 | *, |
| 279 | ch=128, |
| 280 | out_ch=3, |
| 281 | ch_mult=(1, 1, 2, 2, 4), |
| 282 | num_res_blocks=2, |
| 283 | attn_resolutions=(), |
| 284 | dropout=0.0, |
| 285 | resamp_with_conv=True, |
| 286 | in_channels=3, |
| 287 | resolution=256, |
| 288 | z_channels=16, |
| 289 | give_pre_end=False, |
| 290 | **ignore_kwargs, |
| 291 | ): |
| 292 | super().__init__() |
| 293 | self.ch = ch |
| 294 | self.temb_ch = 0 |
| 295 | self.num_resolutions = len(ch_mult) |
| 296 | self.num_res_blocks = num_res_blocks |
| 297 | self.resolution = resolution |
| 298 | self.in_channels = in_channels |
| 299 | self.give_pre_end = give_pre_end |
| 300 | |
| 301 | # compute in_ch_mult, block_in and curr_res at lowest res |
| 302 | in_ch_mult = (1,) + tuple(ch_mult) |
| 303 | block_in = ch * ch_mult[self.num_resolutions - 1] |
| 304 | curr_res = resolution // 2 ** (self.num_resolutions - 1) |
| 305 | self.z_shape = (1, z_channels, curr_res, curr_res) |
| 306 | print( |
| 307 | "Working with z of shape {} = {} dimensions.".format( |
| 308 | self.z_shape, np.prod(self.z_shape) |
| 309 | ) |
| 310 | ) |
| 311 | |
| 312 | # z to block_in |
| 313 | self.conv_in = torch.nn.Conv2d( |
| 314 | z_channels, block_in, kernel_size=3, stride=1, padding=1 |
| 315 | ) |
| 316 | |
| 317 | # middle |
| 318 | self.mid = nn.Module() |
| 319 | self.mid.block_1 = ResnetBlock( |
| 320 | in_channels=block_in, |
| 321 | out_channels=block_in, |
| 322 | temb_channels=self.temb_ch, |
| 323 | dropout=dropout, |
| 324 | ) |
| 325 | self.mid.attn_1 = AttnBlock(block_in) |
| 326 | self.mid.block_2 = ResnetBlock( |
| 327 | in_channels=block_in, |
| 328 | out_channels=block_in, |
| 329 | temb_channels=self.temb_ch, |
| 330 | dropout=dropout, |
| 331 | ) |
| 332 | |