```python
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2):
        super().__init__()
        self.channels = channels
        self.use_conv = use_conv
        self.dims = dims
        # Halve every spatial axis; for 3D signals keep the outermost axis
        # at full resolution and halve only the inner two.
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            # Learned downsampling: kernel size 3, padding 1, strided as above.
            self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1)
        else:
            # Fixed downsampling: average pooling over non-overlapping windows.
            # Assumes avg_pool_nd(dims, *args, **kwargs), mirroring conv_nd;
            # calling avg_pool_nd(stride) would misread the stride as `dims`.
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        # x has shape (N, C, *spatial); the channel count must match.
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(TimestepBlock):
```
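The two branches of `Downsample` halve each downsampled axis in slightly different ways. The sketch below, in plain Python with no PyTorch dependency, computes the resulting spatial shape with the standard PyTorch convolution/pooling output-size formula `out = floor((n + 2p - k) / s) + 1` (dilation 1, floor rounding). It assumes the pooling branch uses `kernel_size = stride` with no padding; `downsampled_size` and `downsample_shape` are hypothetical helpers for illustration, not part of this codebase.

```python
def downsampled_size(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output length along one axis for a conv/pool layer (dilation 1, floor mode)."""
    return (n + 2 * padding - kernel) // stride + 1


def downsample_shape(spatial, dims=2, use_conv=True):
    """Hypothetical helper: spatial shape after Downsample, computed per axis."""
    assert len(spatial) == dims, "expected one size per spatial dimension"
    # Same stride rule as Downsample: 3D keeps the outermost axis unchanged.
    strides = (1, 2, 2) if dims == 3 else (2,) * dims
    out = []
    for n, s in zip(spatial, strides):
        if use_conv:
            # Convolution branch: kernel 3, padding 1, stride s.
            out.append(downsampled_size(n, kernel=3, stride=s, padding=1))
        else:
            # Pooling branch (assumed): kernel == stride, no padding.
            out.append(downsampled_size(n, kernel=s, stride=s, padding=0))
    return tuple(out)


print(downsample_shape((64, 64)))                  # (32, 32)
print(downsample_shape((65, 65)))                  # (33, 33)
print(downsample_shape((65, 65), use_conv=False))  # (32, 32)
print(downsample_shape((8, 64, 64), dims=3))       # (8, 32, 32)
```

For even sizes both branches agree; for odd sizes the padded convolution rounds up (65 → 33) while pooling rounds down (65 → 32), which matters if a matching upsampling path must restore the original size exactly.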