Method forward

models/diffusion.py:301–341 in github.com/ermongroup/ddim
Signature: forward(self, x, t)

```python
def forward(self, x, t):
    # x: image batch of shape (N, C, H, W) with H == W == self.resolution
    # t: timesteps, one per batch element
    assert x.shape[2] == x.shape[3] == self.resolution

    # timestep embedding: sinusoidal features -> dense -> nonlinearity -> dense
    temb = get_timestep_embedding(t, self.ch)
    temb = self.temb.dense[0](temb)
    temb = nonlinearity(temb)
    temb = self.temb.dense[1](temb)

    # downsampling: every intermediate activation is pushed onto hs
    # so the upsampling path can reuse it as a skip connection
    hs = [self.conv_in(x)]
    for i_level in range(self.num_resolutions):
        for i_block in range(self.num_res_blocks):
            h = self.down[i_level].block[i_block](hs[-1], temb)
            if len(self.down[i_level].attn) > 0:
                h = self.down[i_level].attn[i_block](h)
            hs.append(h)
        if i_level != self.num_resolutions-1:
            # no downsample after the lowest-resolution level
            hs.append(self.down[i_level].downsample(hs[-1]))

    # middle: ResNet block, attention, ResNet block
    h = hs[-1]
    h = self.mid.block_1(h, temb)
    h = self.mid.attn_1(h)
    h = self.mid.block_2(h, temb)

    # upsampling: num_res_blocks + 1 blocks per level, each concatenating
    # one popped skip along the channel dimension, so hs ends up empty
    for i_level in reversed(range(self.num_resolutions)):
        for i_block in range(self.num_res_blocks+1):
            h = self.up[i_level].block[i_block](
                torch.cat([h, hs.pop()], dim=1), temb)
            if len(self.up[i_level].attn) > 0:
                h = self.up[i_level].attn[i_block](h)
        if i_level != 0:
            h = self.up[i_level].upsample(h)

    # end: output normalization, nonlinearity, projection to output channels
    h = self.norm_out(h)
    h = nonlinearity(h)
    h = self.conv_out(h)
    return h
```
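The bookkeeping on the skip stack `hs` is the least obvious part of the method. The sketch below replays only its push/pop pattern, using string labels instead of tensors and hypothetical values for `num_resolutions` and `num_res_blocks`; `trace_skips` is an illustrative helper, not part of the repository. With `L` levels and `B` blocks per level, the downsampling path pushes `1 + L*B + (L - 1)` activations and the upsampling path pops `L*(B + 1)`, so the two counts match and `hs` is empty when the method returns.

```python
def trace_skips(num_resolutions, num_res_blocks):
    """Replay forward()'s pushes and pops on hs, using labels instead of tensors.

    Hypothetical helper for illustration; it mirrors only the loop structure.
    """
    hs = ["conv_in"]  # output of conv_in is the first skip
    for i_level in range(num_resolutions):
        for i_block in range(num_res_blocks):
            hs.append(f"down{i_level}.block{i_block}")
        if i_level != num_resolutions - 1:
            hs.append(f"down{i_level}.downsample")
    pushed = len(hs)

    # Upsampling: num_res_blocks + 1 blocks per level, one pop per block.
    popped = []
    for i_level in reversed(range(num_resolutions)):
        for i_block in range(num_res_blocks + 1):
            popped.append((f"up{i_level}.block{i_block}", hs.pop()))
    return pushed, popped, hs


pushed, popped, leftover = trace_skips(num_resolutions=4, num_res_blocks=2)
print(pushed, len(popped), leftover)  # 12 12 []
print(popped[0])   # ('up3.block0', 'down3.block1')
print(popped[-1])  # ('up0.block2', 'conv_in')
```

The pairing also shows why each up level has one extra block: that block consumes the downsample output recorded at the matching resolution (or, at level 0, the `conv_in` output).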

Callers

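Nothing calls this method directly. As the forward method of a torch.nn.Module subclass, it is invoked through the module's __call__, e.g. model(x, t).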
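Because `forward` is reached through the module's `__call__` rather than by name, a call graph shows no direct callers. The toy classes below mimic only that dispatch; `ToyModule` and `ToyModel` are hypothetical stand-ins, and the real `torch.nn.Module.__call__` additionally runs registered hooks. `ToyModel.forward` keeps the same square-resolution guard as the method above but works on shape tuples instead of tensors.

```python
class ToyModule:
    """Hypothetical stand-in for torch.nn.Module: calling the object runs forward()."""

    def __call__(self, *args, **kwargs):
        # The real torch.nn.Module.__call__ also runs registered hooks; omitted here.
        return self.forward(*args, **kwargs)


class ToyModel(ToyModule):
    def __init__(self, resolution):
        self.resolution = resolution

    def forward(self, x_shape, t):
        # Same guard as the real forward: H and W must both equal self.resolution.
        assert x_shape[2] == x_shape[3] == self.resolution
        return x_shape, t


model = ToyModel(resolution=32)
print(model((8, 3, 32, 32), 500))  # ((8, 3, 32, 32), 500)
```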

Calls (2)

get_timestep_embedding · function · 0.85
nonlinearity · function · 0.85
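The two callees are defined outside this method and are not shown here. As a sketch, assuming they follow the DDPM reference implementation, `get_timestep_embedding` maps each timestep to `ch` sinusoidal features (a sine half and a cosine half at geometrically spaced frequencies with base 10000), and `nonlinearity` is the swish (SiLU) activation `x * sigmoid(x)`. The dependency-free functions below handle a single scalar timestep; `sinusoidal_embedding` and `swish` are hypothetical names, and the repository's versions operate on batched torch tensors.

```python
import math


def sinusoidal_embedding(t, dim):
    """Sine/cosine embedding of one timestep t into dim features (sketch)."""
    half = dim // 2
    if half < 2:
        raise ValueError("dim must be at least 4")
    scale = math.log(10000) / (half - 1)
    # Frequencies decay geometrically from 1 down to 1/10000.
    freqs = [math.exp(-scale * i) for i in range(half)]
    args = [t * f for f in freqs]
    emb = [math.sin(a) for a in args] + [math.cos(a) for a in args]
    if dim % 2 == 1:
        emb.append(0.0)  # zero-pad odd dimensions
    return emb


def swish(x):
    """x * sigmoid(x), written to avoid overflow for large |x|."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


print(sinusoidal_embedding(0, 8))  # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
print(swish(0.0))                  # 0.0
```

In `forward`, this embedding then passes through `self.temb.dense[0]`, the nonlinearity, and `self.temb.dense[1]` before being fed to every ResNet block.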

Tested by

No test coverage detected.