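r""" Reconstruct video tensors from patch embeddings. Args: x (List[Tensor]): List of patchified features, each with shape [L, C_out * prod(patch_size)] grid_sizes (Tensor): Original spatial-temporal grid dimensions before patching, shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) Returns: List[Tensor]: Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] """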
(self, x, grid_sizes)
| 582 | return [u.float() for u in x] |
| 583 | |
def unpatchify(self, x, grid_sizes):
    r"""
    Fold per-sample patch tokens back into dense video tensors.

    Args:
        x (List[Tensor]):
            One tensor of patch tokens per sample, each with shape
            [L, C_out * prod(patch_size)]. Only the first
            F_patches * H_patches * W_patches rows of each tensor are
            read; any rows after that are ignored.
        grid_sizes (Tensor):
            Patch-grid extent of every sample, shape [B, 3], ordered as
            (F_patches, H_patches, W_patches).

    Returns:
        List[Tensor]:
            One tensor per sample with shape
            [C_out, F_patches * p_f, H_patches * p_h, W_patches * p_w],
            where (p_f, p_h, p_w) is ``self.patch_size`` and C_out is
            ``self.out_dim``.
    """
    channels = self.out_dim
    patch = self.patch_size

    def _fold(tokens, grid):
        # Keep only the tokens that belong to the grid, then split the
        # token axis into (f, h, w) and the feature axis into
        # (p_f, p_h, p_w, channels).
        n_tokens = math.prod(grid)
        blocks = tokens[:n_tokens].view(*grid, *patch, channels)
        # Move channels to the front and place each patch axis right
        # after its grid axis so the reshape below can merge each pair.
        blocks = torch.einsum('fhwpqrc->cfphqwr', blocks)
        full_extent = [g * p for g, p in zip(grid, patch)]
        return blocks.reshape(channels, *full_extent)

    return [
        _fold(tokens, grid)
        for tokens, grid in zip(x, grid_sizes.tolist())
    ]
| 608 | |
| 609 | def init_weights(self): |
| 610 | r""" |
no outgoing calls
no test coverage detected