A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
| 111 | |
| 112 | |
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: number of channels in the input.
    :param use_conv: a bool determining if a convolution is applied. If
        False, average pooling is used instead, which cannot change the
        channel count, so `out_channels` must then equal `channels`.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    :param out_channels: if specified, the number of output channels;
        otherwise the output has `channels` channels.
    :raises ValueError: if `use_conv` is False and `out_channels` differs
        from `channels`.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        # `or` (rather than `is None`) is kept for backward compatibility:
        # any falsy value, e.g. 0, also means "same as the input".
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # For 3D signals leave the leading axis unstrided and halve only the
        # inner two; otherwise stride every spatial axis by 2.
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            # Kernel 3 with padding 1; the stride performs the downsampling.
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=1
            )
        else:
            # Explicit check instead of `assert`, which `python -O` strips.
            if self.channels != self.out_channels:
                raise ValueError(
                    "Downsample without a convolution cannot change the channel "
                    f"count: channels={self.channels}, "
                    f"out_channels={self.out_channels}"
                )
            # Pooling window equals the stride, so windows do not overlap.
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        """
        Apply the downsampling operation.

        :param x: input tensor whose dimension 1 holds `self.channels`
            channels (assumed layout [N x C x ...] — confirm with callers).
        :return: the output of the strided convolution or average pooling.
        :raises ValueError: if `x.shape[1]` does not equal `self.channels`.
        """
        # Validated explicitly: on the pooling path nothing downstream would
        # notice a channel mismatch, and an `assert` vanishes under `-O`.
        if x.shape[1] != self.channels:
            raise ValueError(
                f"Downsample expected {self.channels} input channels, "
                f"got {x.shape[1]}"
            )
        return self.op(x)
| 141 | |
| 142 | |
| 143 | class ResBlock(TimestepBlock): |