(self, x)
| 243 | ) |
| 244 | |
def forward(self, x):
    """Encode ``x`` and return the output of ``conv_out``.

    Pipeline: ``conv_in`` -> for each of ``num_resolutions`` levels,
    ``num_res_blocks`` blocks (each followed by that level's attention
    module when the level has any) plus a downsample on every level except
    the last -> middle ``block_1`` / ``attn_1`` / ``block_2`` ->
    ``norm_out`` -> ``nonlinearity`` -> ``conv_out``.

    Args:
        x: Input tensor. Presumably (batch, channels, height, width) with
            height == width == ``self.resolution`` -- TODO confirm; this
            shape is not checked here.

    Returns:
        The tensor produced by ``conv_out``.
    """
    # No timestep conditioning in this module: the blocks accept a ``temb``
    # argument, but it is always None here.
    temb = None

    # Downsampling path. Only the most recent activation is ever consumed,
    # so keep a single running tensor instead of appending every
    # intermediate feature map to a list (which kept them all alive until
    # the method returned, inflating peak memory).
    h = self.conv_in(x)
    for i_level in range(self.num_resolutions):
        level = self.down[i_level]
        has_attn = len(level.attn) > 0
        for i_block in range(self.num_res_blocks):
            h = level.block[i_block](h, temb)
            if has_attn:
                h = level.attn[i_block](h)
        # The last level keeps its spatial size; all others downsample.
        if i_level != self.num_resolutions - 1:
            h = level.downsample(h)

    # Middle.
    h = self.mid.block_1(h, temb)
    h = self.mid.attn_1(h)
    h = self.mid.block_2(h, temb)

    # Output head.
    h = self.norm_out(h)
    h = nonlinearity(h)
    h = self.conv_out(h)
    return h
| 273 | |
| 274 | |
| 275 | class Decoder(nn.Module): |
nothing calls this directly
no test coverage detected