An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
| 48 | |
| 49 | |
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs (and in the outputs, unless
        ``out_channels`` is given).
    :param use_conv: a bool determining if a convolution is applied after
        upsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    :param out_channels: optional number of output channels produced by the
        convolution; defaults to ``channels``. Nearest-neighbour upsampling
        alone cannot change the channel count, so a value different from
        ``channels`` requires ``use_conv=True``.
    :raises ValueError: if ``out_channels`` differs from ``channels`` while
        ``use_conv`` is False.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = channels if out_channels is None else out_channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            # Kernel size 3 with padding=1 preserves the (already upsampled)
            # spatial size, assuming conv_nd's default stride of 1.
            # NOTE(review): assumes conv_nd takes (dims, in_ch, out_ch, kernel).
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
        elif self.out_channels != self.channels:
            raise ValueError(
                "Upsample cannot change the channel count without a convolution: "
                f"channels={self.channels}, out_channels={self.out_channels}, "
                "use_conv=False"
            )

    def forward(self, x):
        """
        Upsample ``x`` by a factor of 2 using nearest-neighbour interpolation.

        :param x: a tensor whose dimension 1 is the channel axis and must equal
            ``self.channels``.
        :return: the upsampled tensor; its channel count is
            ``self.out_channels`` when a convolution is applied, otherwise
            unchanged.
        :raises ValueError: if the channel dimension of ``x`` does not match
            ``self.channels``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if x.shape[1] != self.channels:
            raise ValueError(
                f"Upsample expected {self.channels} input channels, "
                f"got {x.shape[1]}"
            )
        if self.dims == 3:
            # Keep the first spatial axis as-is; only the inner two are doubled.
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x
| 79 | |
| 80 | |
| 81 | class Downsample(nn.Module): |