| 34 | |
| 35 | |
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling, optionally followed by a 3x3 convolution.

    Args:
        in_channels: Number of channels of the input; the optional
            convolution maps ``in_channels`` to ``in_channels``.
        with_conv: When truthy, a ``Conv2d`` (kernel 3, stride 1, padding 1)
            is created and applied after the interpolation.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not with_conv:
            # Pure interpolation: no learnable parameters are registered.
            return
        # kernel 3 / stride 1 / padding 1 leaves the spatial size unchanged.
        self.conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Return ``x`` upsampled by a factor of 2 (then convolved if enabled)."""
        upsampled = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode="nearest"
        )
        if not self.with_conv:
            return upsampled
        return self.conv(upsampled)
| 53 | |
| 54 | |
| 55 | class Downsample(nn.Module): |