MCP · Copy · Index your code
hub / github.com/LTH14/mar / Decoder

Class Decoder

models/vae.py:275–396  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

273
274
275class Decoder(nn.Module):
276 def __init__(
277 self,
278 *,
279 ch=128,
280 out_ch=3,
281 ch_mult=(1, 1, 2, 2, 4),
282 num_res_blocks=2,
283 attn_resolutions=(),
284 dropout=0.0,
285 resamp_with_conv=True,
286 in_channels=3,
287 resolution=256,
288 z_channels=16,
289 give_pre_end=False,
290 **ignore_kwargs,
291 ):
292 super().__init__()
293 self.ch = ch
294 self.temb_ch = 0
295 self.num_resolutions = len(ch_mult)
296 self.num_res_blocks = num_res_blocks
297 self.resolution = resolution
298 self.in_channels = in_channels
299 self.give_pre_end = give_pre_end
300
301 # compute in_ch_mult, block_in and curr_res at lowest res
302 in_ch_mult = (1,) + tuple(ch_mult)
303 block_in = ch * ch_mult[self.num_resolutions - 1]
304 curr_res = resolution // 2 ** (self.num_resolutions - 1)
305 self.z_shape = (1, z_channels, curr_res, curr_res)
306 print(
307 "Working with z of shape {} = {} dimensions.".format(
308 self.z_shape, np.prod(self.z_shape)
309 )
310 )
311
312 # z to block_in
313 self.conv_in = torch.nn.Conv2d(
314 z_channels, block_in, kernel_size=3, stride=1, padding=1
315 )
316
317 # middle
318 self.mid = nn.Module()
319 self.mid.block_1 = ResnetBlock(
320 in_channels=block_in,
321 out_channels=block_in,
322 temb_channels=self.temb_ch,
323 dropout=dropout,
324 )
325 self.mid.attn_1 = AttnBlock(block_in)
326 self.mid.block_2 = ResnetBlock(
327 in_channels=block_in,
328 out_channels=block_in,
329 temb_channels=self.temb_ch,
330 dropout=dropout,
331 )
332

Callers 1

__init__ · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected