| 388 | self.repeats = out_channels * self.factor // in_channels |
| 389 | |
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
    """Upsample dims 2, 3 and 4 of ``x`` by duplicating channels and folding them into those dims.

    Every input channel is repeated ``self.repeats`` times along dim 1. The
    enlarged channel axis is then split as
    ``(out_channels, factor_t, factor_s, factor_s)``. Each factor is interleaved
    next to the dim it scales, like a 3-D pixel shuffle:
    ``out[b, c, t*ft + i, h*fs + j, w*fs + k]`` comes from channel
    ``c*ft*fs*fs + i*fs*fs + j*fs + k`` of the repeated tensor.

    Args:
        x: input tensor. Dims 0-4 are read, presumably laid out as
            (batch, channels, time, height, width) -- TODO confirm against
            callers. ``x.size(1) * self.repeats`` must equal
            ``out_channels * factor_t * factor_s * factor_s`` for the reshape
            to succeed.
        first_chunk: when truthy, the first ``factor_t - 1`` entries along
            dim 2 of the upsampled result are dropped.

    Returns:
        Tensor of shape
        ``(B, out_channels, T * factor_t, H * factor_s, W * factor_s)``.
        When ``first_chunk`` is set, dim 2 is shortened by ``factor_t - 1``.
    """
    ft = self.factor_t
    fs = self.factor_s
    n_out = self.out_channels

    duplicated = x.repeat_interleave(self.repeats, dim=1)
    batch = duplicated.size(0)
    frames = duplicated.size(2)
    height = duplicated.size(3)
    width = duplicated.size(4)

    # Split channels into (out_channels, ft, fs, fs), keeping the three trailing dims.
    split = duplicated.view(batch, n_out, ft, fs, fs, frames, height, width)

    # Reorder to (B, C, T, ft, H, fs, W, fs) so each factor sits beside the dim it scales.
    # .contiguous() is needed because view() cannot act on the permuted layout.
    interleaved = split.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()

    upsampled = interleaved.view(batch, n_out, frames * ft, height * fs, width * fs)

    if not first_chunk:
        return upsampled
    return upsampled[:, :, ft - 1:, :, :]
| 413 | |
| 414 | |
| 415 | class Down_ResidualBlock(nn.Module): |