(
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
pool="adaptive",
*args,
**kwargs,
)
| 1014 | """ |
| 1015 | |
| 1016 | def __init__( |
| 1017 | self, |
| 1018 | image_size, |
| 1019 | in_channels, |
| 1020 | model_channels, |
| 1021 | out_channels, |
| 1022 | num_res_blocks, |
| 1023 | attention_resolutions, |
| 1024 | dropout=0, |
| 1025 | channel_mult=(1, 2, 4, 8), |
| 1026 | conv_resample=True, |
| 1027 | dims=2, |
| 1028 | use_checkpoint=False, |
| 1029 | use_fp16=False, |
| 1030 | num_heads=1, |
| 1031 | num_head_channels=-1, |
| 1032 | num_heads_upsample=-1, |
| 1033 | use_scale_shift_norm=False, |
| 1034 | resblock_updown=False, |
| 1035 | use_new_attention_order=False, |
| 1036 | pool="adaptive", |
| 1037 | *args, |
| 1038 | **kwargs, |
| 1039 | ): |
| 1040 | super().__init__() |
| 1041 | |
| 1042 | if num_heads_upsample == -1: |
| 1043 | num_heads_upsample = num_heads |
| 1044 | |
| 1045 | self.in_channels = in_channels |
| 1046 | self.model_channels = model_channels |
| 1047 | self.out_channels = out_channels |
| 1048 | self.num_res_blocks = num_res_blocks |
| 1049 | self.attention_resolutions = attention_resolutions |
| 1050 | self.dropout = dropout |
| 1051 | self.channel_mult = channel_mult |
| 1052 | self.conv_resample = conv_resample |
| 1053 | self.use_checkpoint = use_checkpoint |
| 1054 | self.dtype = th.float16 if use_fp16 else th.float32 |
| 1055 | self.num_heads = num_heads |
| 1056 | self.num_head_channels = num_head_channels |
| 1057 | self.num_heads_upsample = num_heads_upsample |
| 1058 | |
| 1059 | time_embed_dim = model_channels * 4 |
| 1060 | self.time_embed = nn.Sequential( |
| 1061 | linear(model_channels, time_embed_dim), |
| 1062 | nn.SiLU(), |
| 1063 | linear(time_embed_dim, time_embed_dim), |
| 1064 | ) |
| 1065 | |
| 1066 | self.input_blocks = nn.ModuleList( |
| 1067 | [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] |
| 1068 | ) |
| 1069 | self._feature_size = model_channels |
| 1070 | input_block_chans = [model_channels] |
| 1071 | ch = model_channels |
| 1072 | ds = 1 |
| 1073 | for level, mult in enumerate(channel_mult): |
Nothing calls this directly.
No test coverage detected.