3D pixel shuffle.
(x: torch.Tensor, scale_factor: int)
| 2 | |
| 3 | |
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
    """3D pixel shuffle: move channel sub-blocks into a finer spatial grid.

    With ``r = scale_factor``, an input of shape ``(B, C, H, W, D)`` becomes
    ``(B, C // r**3, H * r, W * r, D * r)``. Output voxel
    ``[b, c, h*r + i, w*r + j, d*r + k]`` is read from input channel
    ``c*r**3 + i*r**2 + j*r + k`` at spatial position ``[h, w, d]``. This is
    the same channel-major ordering that ``torch.nn.functional.pixel_shuffle``
    uses in 2-D, extended to a third spatial axis.

    Args:
        x: 5-D tensor of shape ``(B, C, H, W, D)``. ``C`` must be a multiple
            of ``scale_factor ** 3``; otherwise the first reshape raises.
        scale_factor: Upscaling factor ``r`` applied to every spatial axis.

    Returns:
        Tensor of shape ``(B, C // r**3, H * r, W * r, D * r)``.
    """
    r = scale_factor
    batch, channels, height, width, depth = x.shape
    out_channels = channels // r**3

    # Factor the channel axis into (out_channels, r, r, r).
    split = x.reshape(batch, out_channels, r, r, r, height, width, depth)
    # Place each spatial axis next to its sub-factor: (B, C', H, r, W, r, D, r).
    interleaved = split.permute(0, 1, 5, 2, 6, 3, 7, 4)
    # Merging every (axis, r) pair gives the upscaled spatial extents.
    return interleaved.reshape(
        batch, out_channels, height * r, width * r, depth * r
    )
| 14 | |
| 15 | |
| 16 | def patchify(x: torch.Tensor, patch_size: int): |