The half UNet model with attention and timestep embedding. For usage, see UNet.
| 681 | |
| 682 | |
| 683 | class EncoderUNetModel(nn.Module): |
| 684 | """ |
| 685 | The half UNet model with attention and timestep embedding. |
| 686 | |
| 687 | For usage, see UNet. |
| 688 | """ |
| 689 | |
| 690 | def __init__( |
| 691 | self, |
| 692 | image_size, |
| 693 | in_channels, |
| 694 | model_channels, |
| 695 | out_channels, |
| 696 | num_res_blocks, |
| 697 | attention_resolutions, |
| 698 | dropout=0, |
| 699 | channel_mult=(1, 2, 4, 8), |
| 700 | conv_resample=True, |
| 701 | dims=2, |
| 702 | use_checkpoint=False, |
| 703 | use_fp16=False, |
| 704 | num_heads=1, |
| 705 | num_head_channels=-1, |
| 706 | num_heads_upsample=-1, |
| 707 | use_scale_shift_norm=False, |
| 708 | resblock_updown=False, |
| 709 | use_new_attention_order=False, |
| 710 | pool="adaptive", |
| 711 | ): |
| 712 | super().__init__() |
| 713 | |
| 714 | if num_heads_upsample == -1: |
| 715 | num_heads_upsample = num_heads |
| 716 | |
| 717 | self.in_channels = in_channels |
| 718 | self.model_channels = model_channels |
| 719 | self.out_channels = out_channels |
| 720 | self.num_res_blocks = num_res_blocks |
| 721 | self.attention_resolutions = attention_resolutions |
| 722 | self.dropout = dropout |
| 723 | self.channel_mult = channel_mult |
| 724 | self.conv_resample = conv_resample |
| 725 | self.use_checkpoint = use_checkpoint |
| 726 | self.dtype = th.float16 if use_fp16 else th.float32 |
| 727 | self.num_heads = num_heads |
| 728 | self.num_head_channels = num_head_channels |
| 729 | self.num_heads_upsample = num_heads_upsample |
| 730 | |
| 731 | time_embed_dim = model_channels * 4 |
| 732 | self.time_embed = nn.Sequential( |
| 733 | linear(model_channels, time_embed_dim), |
| 734 | nn.SiLU(), |
| 735 | linear(time_embed_dim, time_embed_dim), |
| 736 | ) |
| 737 | |
| 738 | ch = int(channel_mult[0] * model_channels) |
| 739 | self.input_blocks = nn.ModuleList( |
| 740 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] |