MCPcopy Index your code
hub / github.com/LTH14/mar / forward

Method forward

models/vae.py:245–272  ·  view source on GitHub ↗
(self, x)

Source from the content-addressed store, hash-verified

243 )
244
def forward(self, x):
    """Encode ``x`` into the feature map produced by ``self.conv_out``.

    Pipeline:
      1. ``self.conv_in``.
      2. For each of ``self.num_resolutions`` levels in ``self.down``:
         ``self.num_res_blocks`` residual blocks, each followed by its
         attention module when the level has any. Then a ``downsample``
         on every level except the last.
      3. Middle stage: ``mid.block_1`` -> ``mid.attn_1`` -> ``mid.block_2``.
      4. Output head: ``norm_out`` -> ``nonlinearity`` -> ``conv_out``.

    The spatial size of ``x`` is not validated against ``self.resolution``.

    Args:
        x: tensor fed to ``self.conv_in``. Presumably an image batch laid out
            as (B, C, H, W). NOTE(review): confirm against callers.

    Returns:
        The output of ``self.conv_out``.
    """
    # The residual blocks accept a timestep embedding, but this encoder never
    # supplies one, so ``temb`` is always None here.
    temb = None

    # Downsampling path. Each stage consumes only the previous stage's
    # output. A single running tensor is kept instead of a list of every
    # intermediate activation. The list was never read except at its last
    # element, and it held all activations alive until return.
    h = self.conv_in(x)
    for i_level in range(self.num_resolutions):
        level = self.down[i_level]
        for i_block in range(self.num_res_blocks):
            h = level.block[i_block](h, temb)
            # Levels without attention have an empty ``attn`` container.
            if len(level.attn) > 0:
                h = level.attn[i_block](h)
        # No downsampling after the final (lowest-resolution) level.
        if i_level != self.num_resolutions - 1:
            h = level.downsample(h)

    # Middle stage: residual block, attention, residual block.
    h = self.mid.block_1(h, temb)
    h = self.mid.attn_1(h)
    h = self.mid.block_2(h, temb)

    # Output head.
    h = self.norm_out(h)
    h = nonlinearity(h)
    h = self.conv_out(h)
    return h
273
274
275class Decoder(nn.Module):

Callers

nothing calls this directly

Calls 1

nonlinearity · function · 0.85

Tested by

no test coverage detected