(self, dim, mode)
| 66 | class Resample(nn.Module): |
| 67 | |
def __init__(self, dim, mode):
    """Build the spatial (and, for 3d modes, temporal) resampling layers.

    Args:
        dim: Number of input channels.
        mode: One of ``'none'``, ``'upsample2d'``, ``'upsample3d'``,
            ``'downsample2d'`` or ``'downsample3d'``.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    # Validate explicitly rather than with ``assert``: asserts are stripped
    # under ``python -O``, which would let a typo in ``mode`` silently fall
    # through to the Identity branch below.
    valid_modes = ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                   'downsample3d')
    if mode not in valid_modes:
        raise ValueError(
            f'unsupported Resample mode {mode!r}; expected one of {valid_modes}')
    super().__init__()
    self.dim = dim
    self.mode = mode

    # layers
    if mode in ('upsample2d', 'upsample3d'):
        # Spatial path: Upsample (defined elsewhere in this file) with
        # scale_factor=(2., 2.) and 'nearest-exact', followed by a 3x3 conv
        # with padding=1 that keeps spatial size and maps dim -> dim // 2
        # channels.
        self.resample = nn.Sequential(
            Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
            nn.Conv2d(dim, dim // 2, 3, padding=1))
        if mode == 'upsample3d':
            # Temporal conv producing dim * 2 channels with a (3, 1, 1)
            # kernel. NOTE(review): presumably forward() splits these
            # channels to double the time axis — confirm against forward().
            self.time_conv = CausalConv3d(
                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
    elif mode in ('downsample2d', 'downsample3d'):
        # ZeroPad2d((left, right, top, bottom)) = (0, 1, 0, 1) pads one
        # column on the right and one row on the bottom; the unpadded 3x3
        # stride-2 conv then yields floor(H / 2) x floor(W / 2) while the
        # channel count stays at dim.
        self.resample = nn.Sequential(
            nn.ZeroPad2d((0, 1, 0, 1)),
            nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        if mode == 'downsample3d':
            # Strided (2, 1, 1) conv with no padding; presumably the first
            # kernel axis is time, so this subsamples frames — TODO confirm
            # against CausalConv3d / forward().
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
    else:
        # mode == 'none': pass the input through unchanged.
        self.resample = nn.Identity()

| 101 | def forward(self, x, feat_cache=None, feat_idx=[0]): |
| 102 | b, c, t, h, w = x.size() |
nothing calls this directly
no test coverage detected