| 159 | |
| 160 | |
| 161 | class Encoder(nn.Module): |
| 162 | def __init__( |
| 163 | self, |
| 164 | *, |
| 165 | ch=128, |
| 166 | out_ch=3, |
| 167 | ch_mult=(1, 1, 2, 2, 4), |
| 168 | num_res_blocks=2, |
| 169 | attn_resolutions=(16,), |
| 170 | dropout=0.0, |
| 171 | resamp_with_conv=True, |
| 172 | in_channels=3, |
| 173 | resolution=256, |
| 174 | z_channels=16, |
| 175 | double_z=True, |
| 176 | **ignore_kwargs, |
| 177 | ): |
| 178 | super().__init__() |
| 179 | self.ch = ch |
| 180 | self.temb_ch = 0 |
| 181 | self.num_resolutions = len(ch_mult) |
| 182 | self.num_res_blocks = num_res_blocks |
| 183 | self.resolution = resolution |
| 184 | self.in_channels = in_channels |
| 185 | |
| 186 | # downsampling |
| 187 | self.conv_in = torch.nn.Conv2d( |
| 188 | in_channels, self.ch, kernel_size=3, stride=1, padding=1 |
| 189 | ) |
| 190 | |
| 191 | curr_res = resolution |
| 192 | in_ch_mult = (1,) + tuple(ch_mult) |
| 193 | self.down = nn.ModuleList() |
| 194 | for i_level in range(self.num_resolutions): |
| 195 | block = nn.ModuleList() |
| 196 | attn = nn.ModuleList() |
| 197 | block_in = ch * in_ch_mult[i_level] |
| 198 | block_out = ch * ch_mult[i_level] |
| 199 | for i_block in range(self.num_res_blocks): |
| 200 | block.append( |
| 201 | ResnetBlock( |
| 202 | in_channels=block_in, |
| 203 | out_channels=block_out, |
| 204 | temb_channels=self.temb_ch, |
| 205 | dropout=dropout, |
| 206 | ) |
| 207 | ) |
| 208 | block_in = block_out |
| 209 | if curr_res in attn_resolutions: |
| 210 | attn.append(AttnBlock(block_in)) |
| 211 | down = nn.Module() |
| 212 | down.block = block |
| 213 | down.attn = attn |
| 214 | if i_level != self.num_resolutions - 1: |
| 215 | down.downsample = Downsample(block_in, resamp_with_conv) |
| 216 | curr_res = curr_res // 2 |
| 217 | self.down.append(down) |
| 218 | |