Reconstruct video tensors from patch embeddings.
Args:
  x (List[Tensor]): List of patchified features, each with shape [L, C_out * prod(patch_size)]
  grid_sizes (Tensor): Original spatial-temporal grid dimensions before patching
(self, x, grid_sizes)
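The index bookkeeping behind `unpatchify` can be checked by hand without PyTorch. The sketch below is a pure-Python reference on nested lists, not the module's implementation. `unpatchify_reference` and its arguments are hypothetical names introduced here. It assumes each token's feature vector is laid out as (p_t, p_h, p_w, C_out) in row-major order, which is what the method's `view(*v, *self.patch_size, c)` call implies.

```python
import math


def unpatchify_reference(tokens, grid, patch, c_out):
    """Rebuild a [C_out, F*p_t, H*p_h, W*p_w] nested list from patch tokens.

    tokens: list of L feature vectors, each of length c_out * prod(patch)
    grid:   (F, H, W) number of patches along each axis
    patch:  (p_t, p_h, p_w) patch size along each axis
    Hypothetical helper for illustration only.
    """
    F, H, W = grid
    pt, ph, pw = patch
    out = [[[[0.0] * (W * pw) for _ in range(H * ph)]
            for _ in range(F * pt)]
           for _ in range(c_out)]
    # Only the first F*H*W tokens carry data; anything after is padding.
    for l in range(math.prod(grid)):
        f, rem = divmod(l, H * W)
        h, w = divmod(rem, W)
        for k, value in enumerate(tokens[l]):
            # Feature index k unpacks as (p, q, r, channel), channel fastest.
            pqr, ch = divmod(k, c_out)
            p, qr = divmod(pqr, ph * pw)
            q, r = divmod(qr, pw)
            # Each patch offset lands inside the cell owned by its grid position.
            out[ch][f * pt + p][h * ph + q][w * pw + r] = value
    return out


# Tiny demo: a 1x2x2 grid of 1x2x2 patches with one channel, plus one padding token.
demo_tokens = [[10 * l + k for k in range(4)] for l in range(4)] + [[-1] * 4]
demo_out = unpatchify_reference(demo_tokens, grid=(1, 2, 2), patch=(1, 2, 2), c_out=1)
```

In the demo, token `l` holds the values `10 * l + k`, so each 2x2 patch ends up as a contiguous 2x2 tile of the 4x4 output frame and the padding token is ignored.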
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        """
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                    shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """

        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            # Drop padding tokens past F*H*W, then split each token into
            # (F, H, W, p_t, p_h, p_w, C_out).
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            # Move channels first and pair each grid axis with its patch axis.
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            # Merge each (grid, patch) pair into one full-resolution axis.
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""