(self, x, t)
| 299 | padding=1) |
| 300 | |
def forward(self, x, t):
    """Run the full down / middle / up pass on ``x`` for timestep ``t``.

    ``x`` must have its dimensions 2 and 3 both equal to
    ``self.resolution`` (checked with an assertion). ``t`` is handed to
    ``get_timestep_embedding`` together with ``self.ch``; the resulting
    embedding is passed through the two ``self.temb.dense`` layers and
    then supplied to every residual block.

    Returns the output of ``self.conv_out``.
    """
    # NOTE(review): presumably x is laid out (batch, channels, height,
    # width) -- the code only shows that dims 2 and 3 are compared with
    # the resolution and that skip features are concatenated on dim 1.
    assert x.shape[2] == x.shape[3] == self.resolution

    # Timestep embedding: dense -> nonlinearity -> dense.
    time_emb = get_timestep_embedding(t, self.ch)
    time_emb = self.temb.dense[0](time_emb)
    time_emb = nonlinearity(time_emb)
    time_emb = self.temb.dense[1](time_emb)

    # Downsampling path. Every intermediate feature map is pushed onto
    # ``skips`` so the upsampling path can pop them in reverse order.
    skips = [self.conv_in(x)]
    last_level = self.num_resolutions - 1
    for level in range(self.num_resolutions):
        stage = self.down[level]
        for idx in range(self.num_res_blocks):
            feat = stage.block[idx](skips[-1], time_emb)
            # Attention is applied only when this level has attention
            # layers registered.
            if len(stage.attn) > 0:
                feat = stage.attn[idx](feat)
            skips.append(feat)
        # Every level except the last one ends with a downsample step.
        if level != last_level:
            skips.append(stage.downsample(skips[-1]))

    # Middle: block -> attention -> block on the most recent feature map.
    feat = skips[-1]
    feat = self.mid.block_1(feat, time_emb)
    feat = self.mid.attn_1(feat)
    feat = self.mid.block_2(feat, time_emb)

    # Upsampling path, walking the levels from last to first. Each block
    # consumes one stored skip feature, concatenated along dim 1.
    for level in range(self.num_resolutions - 1, -1, -1):
        stage = self.up[level]
        for idx in range(self.num_res_blocks + 1):
            block = stage.block[idx]
            feat = block(torch.cat([feat, skips.pop()], dim=1), time_emb)
            if len(stage.attn) > 0:
                feat = stage.attn[idx](feat)
        # Every level except level 0 ends with an upsample step.
        if level != 0:
            feat = stage.upsample(feat)

    # Output head: normalisation -> nonlinearity -> final convolution.
    feat = self.norm_out(feat)
    feat = nonlinearity(feat)
    return self.conv_out(feat)
nothing calls this directly
no test coverage detected