| 64 | |
| 65 | |
| 66 | class Resample(nn.Module): |
| 67 | |
def __init__(self, dim, mode):
    """Create the resampling layers selected by ``mode``.

    Args:
        dim: Channel count of the incoming feature map.
        mode: One of 'none', 'upsample2d', 'upsample3d', 'downsample2d',
            'downsample3d'. Any other value fails the assert below
            (AssertionError; note asserts are skipped under ``python -O``).

    Attributes set:
        dim, mode: the constructor arguments.
        resample: per-frame spatial layer stack (``nn.Identity`` for 'none').
        time_conv: only for the '*3d' modes; a ``CausalConv3d`` with a
            (3, 1, 1) kernel acting on the first kernel axis.
    """
    assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                    'downsample3d')
    super().__init__()
    self.dim = dim
    self.mode = mode

    # --- spatial stage -------------------------------------------------
    # The 2d and 3d variants use the very same spatial layers; the 3d
    # variants merely add a temporal conv afterwards. The assert above
    # guarantees that `mode` is one of the five literals, so prefix
    # tests are exact here.
    upsampling = mode.startswith('upsample')
    downsampling = mode.startswith('downsample')
    if upsampling:
        # x2 nearest-exact resize over two spatial axes, then a 3x3 conv
        # (padding=1) that maps dim -> dim // 2 channels.
        # NOTE(review): assumes the project's `Upsample` follows
        # nn.Upsample semantics — confirm.
        self.resample = nn.Sequential(
            Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
            nn.Conv2d(dim, dim // 2, 3, padding=1))
    elif downsampling:
        # Zero-pad by (left, right, top, bottom) = (0, 1, 0, 1), then an
        # unpadded stride-2 3x3 conv that keeps the channel count.
        self.resample = nn.Sequential(
            nn.ZeroPad2d((0, 1, 0, 1)),
            nn.Conv2d(dim, dim, 3, stride=(2, 2)))
    else:
        self.resample = nn.Identity()

    # --- temporal stage (3d modes only) -------------------------------
    # Built after `resample` so parameter init order and state_dict
    # registration order match the original layout.
    if mode == 'upsample3d':
        # Outputs dim * 2 channels. NOTE(review): presumably forward()
        # splits these into two frames to double the time axis — confirm.
        self.time_conv = CausalConv3d(
            dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
    elif mode == 'downsample3d':
        # Stride 2 on the first kernel axis with no padding.
        # NOTE(review): presumably halves the time axis, with temporal
        # context supplied by forward()'s feature cache — confirm.
        self.time_conv = CausalConv3d(
            dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
| 100 | |
| 101 | def forward(self, x, feat_cache=None, feat_idx=[0]): |
| 102 | b, c, t, h, w = x.size() |
| 103 | if self.mode == 'upsample3d': |
| 104 | if feat_cache is not None: |
| 105 | idx = feat_idx[0] |
| 106 | if feat_cache[idx] is None: |
| 107 | feat_cache[idx] = 'Rep' |
| 108 | feat_idx[0] += 1 |
| 109 | else: |
| 110 | |
| 111 | cache_x = x[:, :, -CACHE_T:, :, :].clone() |
| 112 | if cache_x.shape[2] < 2 and feat_cache[ |
| 113 | idx] is not None and feat_cache[idx] != 'Rep': |
| 114 | # cache last frame of last two chunk |
| 115 | cache_x = torch.cat([ |
| 116 | feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( |
| 117 | cache_x.device), cache_x |
| 118 | ], |
| 119 | dim=2) |
| 120 | if cache_x.shape[2] < 2 and feat_cache[ |
| 121 | idx] is not None and feat_cache[idx] == 'Rep': |
| 122 | cache_x = torch.cat([ |
| 123 | torch.zeros_like(cache_x).to(cache_x.device), |