(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
)
| 120 | """ |
| 121 | |
def __init__(
    self,
    channels,
    emb_channels,
    dropout,
    out_channels=None,
    use_conv=False,
    use_scale_shift_norm=False,
    dims=2,
    use_checkpoint=False,
):
    """
    Construct the input, embedding, output and skip sub-layers of the block.

    :param channels: number of channels of the block input.
    :param emb_channels: width of the vector consumed by ``emb_layers``.
    :param dropout: dropout probability used by the ``nn.Dropout`` in
        ``out_layers``.
    :param out_channels: number of output channels; any falsy value
        (``None`` by default) falls back to ``channels``.
    :param use_conv: only relevant when the channel count changes; selects a
        3x3 (padded) conv instead of a 1x1 conv for the skip path.
    :param use_scale_shift_norm: when True, ``emb_layers`` produces
        ``2 * out_channels`` features instead of ``out_channels`` (presumably
        split into a scale and a shift in ``forward`` — confirm there).
    :param dims: dimensionality argument forwarded to every ``conv_nd`` call.
    :param use_checkpoint: stored on the instance; its use is not visible in
        this constructor.
    """
    super().__init__()

    out_ch = out_channels or channels

    # Plain configuration values kept on the module.
    self.channels = channels
    self.emb_channels = emb_channels
    self.dropout = dropout
    self.out_channels = out_ch
    self.use_conv = use_conv
    self.use_checkpoint = use_checkpoint
    self.use_scale_shift_norm = use_scale_shift_norm

    # NOTE: sub-modules are built in a fixed order (in -> emb -> out -> skip);
    # keep it that way so parameter initialisation and state_dict keys are stable.
    self.in_layers = nn.Sequential(
        normalization(channels),
        SiLU(),
        conv_nd(dims, channels, out_ch, 3, padding=1),
    )

    emb_width = out_ch * 2 if use_scale_shift_norm else out_ch
    self.emb_layers = nn.Sequential(
        SiLU(),
        linear(emb_channels, emb_width),
    )

    self.out_layers = nn.Sequential(
        normalization(out_ch),
        SiLU(),
        nn.Dropout(p=dropout),
        # The last conv is wrapped by zero_module (presumably zero-initialised
        # so the residual branch starts as a no-op — confirm in zero_module).
        zero_module(conv_nd(dims, out_ch, out_ch, 3, padding=1)),
    )

    # Skip path: identity when shapes already agree, otherwise a projection conv.
    if out_ch == channels:
        skip = nn.Identity()
    elif use_conv:
        skip = conv_nd(dims, channels, out_ch, 3, padding=1)
    else:
        skip = conv_nd(dims, channels, out_ch, 1)
    self.skip_connection = skip
| 171 | |
| 172 | def forward(self, x, emb): |
| 173 | """ |
nothing calls this directly
no test coverage detected