(x, patch_size)
| 278 | |
| 279 | |
def patchify(x, patch_size):
    """Move non-overlapping spatial patches of ``x`` into the channel axis.

    The last two axes of ``x`` are cut into ``patch_size`` x ``patch_size``
    blocks. Each block's pixels are stacked onto the channel axis in
    ``(c, r, q)`` order: channel first, then column offset ``r``, then row
    offset ``q``. The matching inverse must use the same ordering.

    Args:
        x: Tensor-like object exposing ``dim()`` and ``shape``. Accepted
            layouts are 4-D ``b c H W`` and 5-D ``b c f H W``. Presumably
            ``H`` and ``W`` must be multiples of ``patch_size``;
            ``rearrange`` would reject them otherwise. TODO confirm.
        patch_size: Edge length of each square patch. A value of 1 returns
            ``x`` unchanged, without checking its rank.

    Returns:
        ``x`` when ``patch_size == 1``. Otherwise a rearranged tensor of
        shape ``b (c*p*p) H/p W/p`` or ``b (c*p*p) f H/p W/p``.

    Raises:
        ValueError: If ``x`` is neither 4-D nor 5-D (and ``patch_size != 1``).
    """
    # A patch size of 1 is the identity; skip the rank check and the copy.
    if patch_size == 1:
        return x

    # einops pattern for each supported rank. The 5-D case carries an extra
    # frame axis ``f`` through untouched.
    pattern_by_rank = {
        4: "b c (h q) (w r) -> b (c r q) h w",
        5: "b c f (h q) (w r) -> b (c r q) f h w",
    }
    pattern = pattern_by_rank.get(x.dim())
    if pattern is None:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return rearrange(x, pattern, q=patch_size, r=patch_size)
| 297 | |
| 298 | |
| 299 | def unpatchify(x, patch_size): |