| 190 | |
| 191 | |
| 192 | class Model(nn.Module): |
| 193 | def __init__(self, config): |
| 194 | super().__init__() |
| 195 | self.config = config |
| 196 | ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult) |
| 197 | num_res_blocks = config.model.num_res_blocks |
| 198 | attn_resolutions = config.model.attn_resolutions |
| 199 | dropout = config.model.dropout |
| 200 | in_channels = config.model.in_channels |
| 201 | resolution = config.data.image_size |
| 202 | resamp_with_conv = config.model.resamp_with_conv |
| 203 | num_timesteps = config.diffusion.num_diffusion_timesteps |
| 204 | |
| 205 | if config.model.type == 'bayesian': |
| 206 | self.logvar = nn.Parameter(torch.zeros(num_timesteps)) |
| 207 | |
| 208 | self.ch = ch |
| 209 | self.temb_ch = self.ch*4 |
| 210 | self.num_resolutions = len(ch_mult) |
| 211 | self.num_res_blocks = num_res_blocks |
| 212 | self.resolution = resolution |
| 213 | self.in_channels = in_channels |
| 214 | |
| 215 | # timestep embedding |
| 216 | self.temb = nn.Module() |
| 217 | self.temb.dense = nn.ModuleList([ |
| 218 | torch.nn.Linear(self.ch, |
| 219 | self.temb_ch), |
| 220 | torch.nn.Linear(self.temb_ch, |
| 221 | self.temb_ch), |
| 222 | ]) |
| 223 | |
| 224 | # downsampling |
| 225 | self.conv_in = torch.nn.Conv2d(in_channels, |
| 226 | self.ch, |
| 227 | kernel_size=3, |
| 228 | stride=1, |
| 229 | padding=1) |
| 230 | |
| 231 | curr_res = resolution |
| 232 | in_ch_mult = (1,)+ch_mult |
| 233 | self.down = nn.ModuleList() |
| 234 | block_in = None |
| 235 | for i_level in range(self.num_resolutions): |
| 236 | block = nn.ModuleList() |
| 237 | attn = nn.ModuleList() |
| 238 | block_in = ch*in_ch_mult[i_level] |
| 239 | block_out = ch*ch_mult[i_level] |
| 240 | for i_block in range(self.num_res_blocks): |
| 241 | block.append(ResnetBlock(in_channels=block_in, |
| 242 | out_channels=block_out, |
| 243 | temb_channels=self.temb_ch, |
| 244 | dropout=dropout)) |
| 245 | block_in = block_out |
| 246 | if curr_res in attn_resolutions: |
| 247 | attn.append(AttnBlock(block_in)) |
| 248 | down = nn.Module() |
| 249 | down.block = block |