(self, coords, batch_size)
| 465 | return torch.Size(shape) |
| 466 | |
def __cal_layout(self, coords, batch_size):
    """Compute the per-batch row slices into ``coords``.

    Args:
        coords: 2-D integer tensor whose column 0 holds the batch index of
            each row (``torch.bincount`` requires these to be non-negative).
            NOTE(review): the slices are only meaningful if rows are grouped
            (sorted) by batch index, so each batch is one contiguous run --
            confirm with callers.
        batch_size: number of batches to produce slices for. A batch with no
            rows gets an empty slice. Counts for batch indices >= batch_size
            are ignored.

    Returns:
        A list of ``batch_size`` ``slice`` objects with Python-int bounds.
        Slice ``i`` spans the rows counted for batch ``i``.
    """
    # Row count per batch index; minlength guarantees >= batch_size entries.
    seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
    # One device-to-host transfer instead of 2 * batch_size `.item()` calls.
    # Each `.item()` is a separate synchronization when the data is on a GPU.
    ends = torch.cumsum(seq_len, dim=0)[:batch_size].tolist()
    # Batch i starts where batch i - 1 ended; batch 0 starts at row 0.
    starts = [0] + ends[:-1]
    return [slice(start, end) for start, end in zip(starts, ends)]
| 472 | |
def __cal_spatial_shape(self, coords):
    """Derive the spatial extent covered by ``coords`` as a ``torch.Size``.

    Column 0 of the 2-D ``coords`` tensor is skipped. Every remaining column
    is treated as a spatial index, and that dimension's extent is its
    largest index plus one.
    NOTE(review): a ``coords`` tensor with zero rows makes the reduction
    below raise.
    """
    spatial_idx = coords[:, 1:]
    # torch.max along a dim returns (values, indices); only values are needed.
    largest, _ = spatial_idx.max(dim=0)
    extent = largest + 1
    return torch.Size(extent.tolist())