(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, **kwargs)
| 153 | |
| 154 | class Encoder(nn.Module): |
| 155 | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, |
| 156 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, |
| 157 | resolution, z_channels, double_z=True, **kwargs): |
| 158 | super().__init__() |
| 159 | self.ch = ch |
| 160 | self.temb_ch = 0 |
| 161 | self.num_resolutions = len(ch_mult) |
| 162 | self.num_res_blocks = num_res_blocks |
| 163 | self.resolution = resolution |
| 164 | self.in_channels = in_channels |
| 165 | |
| 166 | # downsampling |
| 167 | self.conv_in = torch.nn.Conv2d(in_channels, |
| 168 | self.ch, |
| 169 | kernel_size=3, |
| 170 | stride=1, |
| 171 | padding=1) |
| 172 | |
| 173 | curr_res = resolution |
| 174 | in_ch_mult = (1,) + tuple(ch_mult) |
| 175 | self.in_ch_mult = in_ch_mult |
| 176 | self.down = nn.ModuleList() |
| 177 | for i_level in range(self.num_resolutions): |
| 178 | block = nn.ModuleList() |
| 179 | attn = nn.ModuleList() |
| 180 | block_in = ch * in_ch_mult[i_level] |
| 181 | block_out = ch * ch_mult[i_level] |
| 182 | for i_block in range(self.num_res_blocks): |
| 183 | block.append(ResnetBlock(in_channels=block_in, |
| 184 | out_channels=block_out, |
| 185 | temb_channels=self.temb_ch, |
| 186 | dropout=dropout)) |
| 187 | block_in = block_out |
| 188 | if curr_res in attn_resolutions: |
| 189 | attn.append(Attention(block_in)) |
| 190 | down = nn.Module() |
| 191 | down.block = block |
| 192 | down.attn = attn |
| 193 | if i_level != self.num_resolutions - 1: |
| 194 | down.downsample = EncoderDownSampleBlock(block_in, resamp_with_conv) |
| 195 | curr_res = curr_res // 2 |
| 196 | self.down.append(down) |
| 197 | |
| 198 | # middle |
| 199 | self.mid = nn.Module() |
| 200 | self.mid.block_1 = ResnetBlock(in_channels=block_in, |
| 201 | out_channels=block_in, |
| 202 | temb_channels=self.temb_ch, |
| 203 | dropout=dropout) |
| 204 | self.mid.attn_1 = Attention(block_in) |
| 205 | self.mid.block_2 = ResnetBlock(in_channels=block_in, |
| 206 | out_channels=block_in, |
| 207 | temb_channels=self.temb_ch, |
| 208 | dropout=dropout) |
| 209 | |
| 210 | # end |
| 211 | self.norm_out = GroupNorm(block_in) |
| 212 | self.conv_out = torch.nn.Conv2d(block_in, |
no test coverage detected