(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
)
| 176 | """ |
| 177 | |
| 178 | def __init__( |
| 179 | self, |
| 180 | channels, |
| 181 | emb_channels, |
| 182 | dropout, |
| 183 | out_channels=None, |
| 184 | use_conv=False, |
| 185 | use_scale_shift_norm=False, |
| 186 | dims=2, |
| 187 | use_checkpoint=False, |
| 188 | up=False, |
| 189 | down=False, |
| 190 | ): |
| 191 | super().__init__() |
| 192 | self.channels = channels |
| 193 | self.emb_channels = emb_channels |
| 194 | self.dropout = dropout |
| 195 | self.out_channels = out_channels or channels |
| 196 | self.use_conv = use_conv |
| 197 | self.use_checkpoint = use_checkpoint |
| 198 | self.use_scale_shift_norm = use_scale_shift_norm |
| 199 | |
| 200 | self.in_layers = nn.Sequential( |
| 201 | normalization(channels), |
| 202 | nn.SiLU(), |
| 203 | conv_nd(dims, channels, self.out_channels, 3, padding=1), |
| 204 | ) |
| 205 | |
| 206 | self.updown = up or down |
| 207 | |
| 208 | if up: |
| 209 | self.h_upd = Upsample(channels, False, dims) |
| 210 | self.x_upd = Upsample(channels, False, dims) |
| 211 | elif down: |
| 212 | self.h_upd = Downsample(channels, False, dims) |
| 213 | self.x_upd = Downsample(channels, False, dims) |
| 214 | else: |
| 215 | self.h_upd = self.x_upd = nn.Identity() |
| 216 | |
| 217 | self.emb_layers = nn.Sequential( |
| 218 | nn.SiLU(), |
| 219 | linear( |
| 220 | emb_channels, |
| 221 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, |
| 222 | ), |
| 223 | ) |
| 224 | self.out_layers = nn.Sequential( |
| 225 | normalization(self.out_channels), |
| 226 | nn.SiLU(), |
| 227 | nn.Dropout(p=dropout), |
| 228 | zero_module( |
| 229 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) |
| 230 | ), |
| 231 | ) |
| 232 | |
| 233 | if self.out_channels == channels: |
| 234 | self.skip_connection = nn.Identity() |
| 235 | elif use_conv: |
nothing calls this directly
no test coverage detected