| 109 | self.resample = nn.Identity() |
| 110 | |
def forward(self, x, feat_cache=None, feat_idx=None):
    """Resample ``x``, optionally carrying temporal state between chunks.

    ``x`` is unpacked as a 5-D tensor ``(b, c, t, h, w)``. The per-frame
    ``self.resample`` module is always applied; ``self.time_conv`` is only
    used in the ``"upsample3d"`` / ``"downsample3d"`` modes and only when a
    ``feat_cache`` is supplied.

    Args:
        x: 5-D input tensor ``(b, c, t, h, w)``.
        feat_cache: Optional mutable sequence of per-layer cache slots. The
            slot at ``feat_idx[0]`` is read and overwritten in place. A slot
            holds ``None`` (never visited), the string ``"Rep"`` (upsample
            path visited once, no frames cached yet) or a tensor of trailing
            frames from the previous call.
        feat_idx: Optional one-element list used as a shared, mutable slot
            counter; ``feat_idx[0]`` is incremented whenever a slot is
            consumed. Defaults to a fresh ``[0]`` for every call.

    Returns:
        The resampled tensor, laid out as ``(b, c', t', h', w')``.
    """
    if feat_idx is None:
        # The counter is mutated in place, so a literal ``[0]`` default would
        # be shared by every call that omits the argument.
        feat_idx = [0]

    b, c, t, h, w = x.size()

    if self.mode == "upsample3d" and feat_cache is not None:
        idx = feat_idx[0]
        prev = feat_cache[idx]
        if prev is None:
            # First visit: only mark the slot; x is left untouched here.
            feat_cache[idx] = "Rep"
            feat_idx[0] += 1
        else:
            # Explicit sentinel test: avoids comparing a tensor with a str.
            prev_is_rep = isinstance(prev, str) and prev == "Rep"

            # Trailing frames of this chunk, stored for the next call.
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2:
                if prev_is_rep:
                    # No earlier frames exist yet: pad in front with zeros.
                    pad = torch.zeros_like(cache_x)
                else:
                    # Prepend the last frame cached from the previous chunk.
                    pad = prev[:, :, -1:, :, :].to(cache_x.device)
                cache_x = torch.cat([pad, cache_x], dim=2)

            # NOTE(review): time_conv presumably accepts the previous frames
            # as an optional second argument -- confirm against its class.
            if prev_is_rep:
                x = self.time_conv(x)
            else:
                x = self.time_conv(x, prev)
            feat_cache[idx] = cache_x
            feat_idx[0] += 1

            # time_conv output is viewed as two groups of c channels, which
            # are interleaved along the time axis: (b, c, t, 2, h, w) ->
            # (b, c, 2 * t, h, w).
            x = x.reshape(b, 2, c, t, h, w)
            x = torch.stack((x[:, 0], x[:, 1]), 3)
            x = x.reshape(b, c, t * 2, h, w)

    # Apply the per-frame resampler by folding time into the batch axis.
    t = x.shape[2]
    x = rearrange(x, "b c t h w -> (b t) c h w")
    x = self.resample(x)
    x = rearrange(x, "(b t) c h w -> b c t h w", t=t)

    if self.mode == "downsample3d" and feat_cache is not None:
        idx = feat_idx[0]
        prev = feat_cache[idx]
        if prev is None:
            # First visit: store the resampled chunk; time_conv is skipped.
            feat_cache[idx] = x.clone()
        else:
            last_frame = x[:, :, -1:, :, :].clone()
            # Prepend the last cached frame before the temporal convolution.
            x = self.time_conv(torch.cat([prev[:, :, -1:, :, :], x], 2))
            feat_cache[idx] = last_frame
        feat_idx[0] += 1

    return x