(self, in_channels, with_conv)
| 18 | |
| 19 | class Upsample(nn.Module): |
| 20 | def __init__(self, in_channels, with_conv): |
| 21 | super().__init__() |
| 22 | self.with_conv = with_conv |
| 23 | if self.with_conv: |
| 24 | self.conv = torch.nn.Conv2d( |
| 25 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 |
| 26 | ) |
| 27 | |
| 28 | def forward(self, x): |
| 29 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") |