| 326 | |
@torch.no_grad()
def _expand_for_blocking(idxs, blocking):
    """Expand block-level coordinate pairs into every element they cover.

    Each entry of ``idxs`` is a pair ``(a, b)`` of block coordinates
    (presumably ``(row, col)`` -- confirm against callers). A block of edge
    length ``blocking`` located at ``(a, b)`` covers the element coordinates
    ``(a * blocking + i, b * blocking + j)`` for ``0 <= i, j < blocking``.

    Args:
        idxs: Tensor holding ``N`` coordinate pairs; it must be reshapeable
            to ``(N, 2)`` where ``N = idxs.size(0)``. It is not modified.
        blocking: Edge length of a block.

    Returns:
        Tensor of shape ``(N * blocking * blocking, 2)`` on the same device
        and with the same dtype as ``idxs``. The expanded coordinates of
        each input pair are contiguous and ordered ``i``-major, then ``j``.
    """
    num_blocks = idxs.size(0)
    pairs = torch.reshape(idxs, (num_blocks, 2))

    # Build the offsets in idxs' dtype so broadcasting arithmetic keeps the
    # input dtype instead of promoting (e.g. int32 -> int64).
    offsets = torch.arange(blocking, device=idxs.device, dtype=idxs.dtype)

    # First coordinate varies along dim 1 (i), second along dim 2 (j).
    first = (pairs[:, 0].reshape(num_blocks, 1, 1) * blocking
             + offsets.reshape(1, blocking, 1))
    second = (pairs[:, 1].reshape(num_blocks, 1, 1) * blocking
              + offsets.reshape(1, 1, blocking))
    first, second = torch.broadcast_tensors(first, second)

    # (N, blocking, blocking, 2) -> flat list of coordinate pairs.
    return torch.stack((first, second), dim=-1).reshape(-1, 2)
| 341 | |
| 342 | |
| 343 | @torch.no_grad() |