A residual block that can optionally change the number of channels. :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels. :param dropout: the dropout rate. :param out_channels: if specified, the number of output channels; otherwise defaults to `channels`.
| 141 | |
| 142 | |
| 143 | class ResBlock(TimestepBlock): |
| 144 | """ |
| 145 | A residual block that can optionally change the number of channels. |
| 146 | |
| 147 | :param channels: the number of input channels. |
| 148 | :param emb_channels: the number of timestep embedding channels. |
| 149 | :param dropout: the rate of dropout. |
| 150 | :param out_channels: if specified, the number of out channels. |
| 151 | :param use_conv: if True and out_channels is specified, use a spatial |
| 152 | convolution instead of a smaller 1x1 convolution to change the |
| 153 | channels in the skip connection. |
| 154 | :param dims: determines if the signal is 1D, 2D, or 3D. |
| 155 | :param use_checkpoint: if True, use gradient checkpointing on this module. |
| 156 | :param up: if True, use this block for upsampling. |
| 157 | :param down: if True, use this block for downsampling. |
| 158 | """ |
| 159 | |
| 160 | def __init__( |
| 161 | self, |
| 162 | channels, |
| 163 | emb_channels, |
| 164 | dropout, |
| 165 | out_channels=None, |
| 166 | use_conv=False, |
| 167 | use_scale_shift_norm=False, |
| 168 | dims=2, |
| 169 | use_checkpoint=False, |
| 170 | up=False, |
| 171 | down=False, |
| 172 | ): |
| 173 | super().__init__() |
| 174 | self.channels = channels |
| 175 | self.emb_channels = emb_channels |
| 176 | self.dropout = dropout |
| 177 | self.out_channels = out_channels or channels |
| 178 | self.use_conv = use_conv |
| 179 | self.use_checkpoint = use_checkpoint |
| 180 | self.use_scale_shift_norm = use_scale_shift_norm |
| 181 | |
| 182 | self.in_layers = nn.Sequential( |
| 183 | normalization(channels), |
| 184 | nn.SiLU(), |
| 185 | conv_nd(dims, channels, self.out_channels, 3, padding=1), |
| 186 | ) |
| 187 | |
| 188 | self.updown = up or down |
| 189 | |
| 190 | if up: |
| 191 | self.h_upd = Upsample(channels, False, dims) |
| 192 | self.x_upd = Upsample(channels, False, dims) |
| 193 | elif down: |
| 194 | self.h_upd = Downsample(channels, False, dims) |
| 195 | self.x_upd = Downsample(channels, False, dims) |
| 196 | else: |
| 197 | self.h_upd = self.x_upd = nn.Identity() |
| 198 | |
| 199 | self.emb_layers = nn.Sequential( |
| 200 | nn.SiLU(), |