The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample.
| 276 | |
| 277 | |
| 278 | class UNetModel(nn.Module): |
| 279 | """ |
| 280 | The full UNet model with attention and timestep embedding. |
| 281 | |
| 282 | :param in_channels: channels in the input Tensor. |
| 283 | :param model_channels: base channel count for the model. |
| 284 | :param out_channels: channels in the output Tensor. |
| 285 | :param num_res_blocks: number of residual blocks per downsample. |
| 286 | :param attention_resolutions: a collection of downsample rates at which |
| 287 | attention will take place. May be a set, list, or tuple. |
| 288 | For example, if this contains 4, then at 4x downsampling, attention |
| 289 | will be used. |
| 290 | :param dropout: the dropout probability. |
| 291 | :param channel_mult: channel multiplier for each level of the UNet. |
| 292 | :param conv_resample: if True, use learned convolutions for upsampling and |
| 293 | downsampling. |
| 294 | :param dims: determines if the signal is 1D, 2D, or 3D. |
| 295 | :param num_classes: if specified (as an int), then this model will be |
| 296 | class-conditional with `num_classes` classes. |
| 297 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. |
| 298 | :param num_heads: the number of attention heads in each attention layer. |
| 299 | """ |
| 300 | |
| 301 | def __init__( |
| 302 | self, |
| 303 | in_channels, |
| 304 | model_channels, |
| 305 | out_channels, |
| 306 | num_res_blocks, |
| 307 | attention_resolutions, |
| 308 | dropout=0, |
| 309 | channel_mult=(1, 2, 4, 8), |
| 310 | conv_resample=True, |
| 311 | dims=2, |
| 312 | num_classes=None, |
| 313 | use_checkpoint=False, |
| 314 | num_heads=1, |
| 315 | num_heads_upsample=-1, |
| 316 | use_scale_shift_norm=False, |
| 317 | ): |
| 318 | super().__init__() |
| 319 | |
| 320 | if num_heads_upsample == -1: |
| 321 | num_heads_upsample = num_heads |
| 322 | |
| 323 | self.in_channels = in_channels |
| 324 | self.model_channels = model_channels |
| 325 | self.out_channels = out_channels |
| 326 | self.num_res_blocks = num_res_blocks |
| 327 | self.attention_resolutions = attention_resolutions |
| 328 | self.dropout = dropout |
| 329 | self.channel_mult = channel_mult |
| 330 | self.conv_resample = conv_resample |
| 331 | self.num_classes = num_classes |
| 332 | self.use_checkpoint = use_checkpoint |
| 333 | self.num_heads = num_heads |
| 334 | self.num_heads_upsample = num_heads_upsample |
| 335 |