```python
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for
        potentially increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads
```
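The `UNetModel` docstring says the model uses a "timestep embedding" but this excerpt does not show how timesteps are encoded. A common choice in diffusion UNets is a sinusoidal embedding in the style of Transformer positional encodings, and the pure-Python sketch below assumes that scheme. The function name `timestep_embedding`, the `max_period=10000` default, and the cosine-then-sine layout are illustrative assumptions, not taken from this file.

```python
import math
from typing import List


def timestep_embedding(t: float, dim: int, max_period: float = 10000.0) -> List[float]:
    """Sinusoidal embedding of a single (possibly fractional) timestep.

    Assumed scheme (sketch only): frequencies are geometrically spaced from 1
    down toward 1/max_period; the first half of the vector holds cosines and
    the second half sines. An odd `dim` is padded with a trailing zero.
    """
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:
        emb.append(0.0)  # keep the output length equal to dim
    return emb


print(timestep_embedding(0, 8))  # → [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```

Because the frequencies span several orders of magnitude, nearby timesteps differ mostly in the high-frequency components while distant ones also differ in the low-frequency ones, which gives the network a smooth, bounded encoding of `t`.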
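The `channel_mult` and `attention_resolutions` parameters work together: each entry of `channel_mult` defines one resolution level, and attention is enabled at a level when its downsample rate appears in `attention_resolutions`. The sketch below lists the levels a configuration implies. It assumes the downsample rate starts at 1 and doubles from one level to the next, which matches the docstring's "4x downsampling" example but is not shown in this excerpt; `plan_levels` and `Level` are hypothetical names.

```python
from typing import Iterable, List, NamedTuple


class Level(NamedTuple):
    level: int        # position in channel_mult
    downsample: int   # downsample rate relative to the input resolution
    channels: int     # model_channels * channel_mult[level]
    attention: bool   # True if `downsample` appears in attention_resolutions


def plan_levels(model_channels: int,
                channel_mult: Iterable[int],
                attention_resolutions: Iterable[int]) -> List[Level]:
    """Hypothetical helper: list the resolution levels a config implies.

    Assumes the downsample rate is 1 at the first level and doubles
    between consecutive levels.
    """
    attn = set(attention_resolutions)  # accepts a set, list, or tuple
    levels = []
    ds = 1
    for i, mult in enumerate(channel_mult):
        levels.append(Level(i, ds, model_channels * mult, ds in attn))
        ds *= 2
    return levels


for lvl in plan_levels(128, (1, 2, 4, 8), [4]):
    print(lvl)
```

With `model_channels=128`, the default `channel_mult=(1, 2, 4, 8)` and `attention_resolutions=[4]`, the widths are 128, 256, 512 and 1024 channels, and only the third level (4x downsampling, 512 channels) gets attention, as in the docstring's example.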
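Two rules govern the number of attention heads. The constructor shows the first directly: a `num_heads_upsample` of -1 falls back to `num_heads`. The docstring describes the second: when `num_head_channels` is specified, `num_heads` is ignored and each head gets a fixed channel width. The sketch below encodes both as plain functions. The helper names, the reading of -1 as "not specified" for `num_head_channels` (inferred from its default in the signature), and the divisibility check are assumptions.

```python
def resolve_num_heads(channels: int, num_heads: int, num_head_channels: int = -1) -> int:
    """Hypothetical helper: head count for an attention layer with `channels` channels.

    If num_head_channels is -1, use num_heads as given; otherwise ignore
    num_heads and give every head exactly num_head_channels channels.
    """
    if num_head_channels == -1:
        return num_heads
    if channels % num_head_channels != 0:
        raise ValueError(
            f"channels ({channels}) must be divisible by "
            f"num_head_channels ({num_head_channels})"
        )
    return channels // num_head_channels


def resolve_num_heads_upsample(num_heads: int, num_heads_upsample: int = -1) -> int:
    # Same defaulting as the constructor: -1 means "reuse num_heads".
    return num_heads if num_heads_upsample == -1 else num_heads_upsample


print(resolve_num_heads(512, num_heads=4))                        # → 4
print(resolve_num_heads(512, num_heads=4, num_head_channels=64))  # → 8
print(resolve_num_heads_upsample(4))                              # → 4
```

A fixed per-head width means wider levels automatically get more heads (512 channels at 64 per head gives 8 heads, 1024 gives 16), whereas a fixed `num_heads` makes each head wider as the channel count grows.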
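The docstring only says `use_scale_shift_norm` enables "a FiLM-like conditioning mechanism". FiLM (feature-wise linear modulation) scales and shifts each feature channel using values predicted from a conditioning input, here presumably the timestep embedding. The pure-Python sketch below assumes the `(1 + scale)` parameterization, so a zero scale leaves the normalized activations unchanged. It also uses a simple per-channel normalization in place of whatever normalization layer the real residual block uses. `normalize` and `scale_shift_norm` are hypothetical names.

```python
import math
from typing import List, Sequence


def normalize(xs: Sequence[float], eps: float = 1e-5) -> List[float]:
    # Zero-mean, unit-variance normalization over one channel's values
    # (a stand-in for the normalization layer a real model would use).
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return [(x - mean) / math.sqrt(var + eps) for x in xs]


def scale_shift_norm(h: Sequence[Sequence[float]],
                     scale: Sequence[float],
                     shift: Sequence[float]) -> List[List[float]]:
    """FiLM-style conditioning: norm(h) * (1 + scale) + shift, per channel.

    h is a list of channels, each a flat list of spatial values; scale and
    shift hold one value per channel. In a real model they would be produced
    by a learned projection of the conditioning embedding.
    """
    if not (len(h) == len(scale) == len(shift)):
        raise ValueError("h, scale and shift must have one entry per channel")
    return [
        [v * (1.0 + s) + b for v in normalize(channel)]
        for channel, s, b in zip(h, scale, shift)
    ]


h = [[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 2.0, 2.0]]
print(scale_shift_norm(h, scale=[0.0, 1.0], shift=[0.5, -1.0]))
```

Applying the scale and shift after normalization lets the conditioning signal set each channel's gain and offset directly, instead of being washed out by the normalization that follows an additive injection.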