| 53 | |
| 54 | |
class Downsample(nn.Module):
    """Halve the two trailing (spatial) dimensions of a feature map.

    With ``with_conv`` set, the input is zero-padded by one element on the
    right and on the bottom and then passed through a 3x3, stride-2
    convolution that keeps the channel count at ``in_channels``. Otherwise a
    parameter-free 2x2 average pooling with stride 2 is applied.

    Args:
        in_channels: Number of input (and output) channels of the
            convolution. Unused when ``with_conv`` is false.
        with_conv: Select the learned convolution (truthy) or average
            pooling (falsy).
    """

    def __init__(self, in_channels: int, with_conv: bool) -> None:
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # Conv2d only supports symmetric padding, so the layer is built
            # with padding=0 and the asymmetric pad is applied in forward().
            self.conv = torch.nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=3,
                stride=2,
                padding=0,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` downsampled by a factor of two.

        Args:
            x: Input feature map; presumably laid out as (N, C, H, W) --
                confirm against callers.

        Returns:
            The downsampled tensor. For even H and W both paths yield
            H/2 x W/2 spatial dimensions.
        """
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)

        # F.pad order is (left, right, top, bottom) for the last two dims:
        # add one zero column on the right and one zero row at the bottom.
        pad = (0, 1, 0, 1)
        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        return self.conv(x)
| 75 | |
| 76 | |
| 77 | class ResnetBlock(nn.Module): |