| 637 | self.act = PointSequential(act_layer()) |
| 638 | |
| 639 | def forward(self, point: Point): |
| 640 | pooling_depth = (math.ceil(self.stride) - 1).bit_length() |
| 641 | if pooling_depth > point.serialized_depth: |
| 642 | pooling_depth = 0 |
| 643 | assert { |
| 644 | "serialized_code", |
| 645 | "serialized_order", |
| 646 | "serialized_inverse", |
| 647 | "serialized_depth", |
| 648 | }.issubset( |
| 649 | point.keys() |
| 650 | ), "Run point.serialization() point cloud before SerializedPooling" |
| 651 | |
| 652 | code = point.serialized_code >> pooling_depth * 3 |
| 653 | code_, cluster, counts = torch.unique( |
| 654 | code[0], |
| 655 | sorted=True, |
| 656 | return_inverse=True, |
| 657 | return_counts=True, |
| 658 | ) |
| 659 | # indices of point sorted by cluster, for torch_scatter.segment_csr |
| 660 | _, indices = torch.sort(cluster) |
| 661 | # index pointer for sorted point, for torch_scatter.segment_csr |
| 662 | idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)]) |
| 663 | # head_indices of each cluster, for reduce attr e.g. code, batch |
| 664 | head_indices = indices[idx_ptr[:-1]] |
| 665 | # generate down code, order, inverse |
| 666 | code = code[:, head_indices] |
| 667 | order = torch.argsort(code) |
| 668 | inverse = torch.zeros_like(order).scatter_( |
| 669 | dim=1, |
| 670 | index=order, |
| 671 | src=torch.arange(0, code.shape[1], device=order.device).repeat( |
| 672 | code.shape[0], 1 |
| 673 | ), |
| 674 | ) |
| 675 | |
| 676 | if self.shuffle_orders: |
| 677 | perm = torch.randperm(code.shape[0]) |
| 678 | code = code[perm] |
| 679 | order = order[perm] |
| 680 | inverse = inverse[perm] |
| 681 | |
| 682 | # collect information |
| 683 | point_dict = Dict( |
| 684 | feat=torch_scatter.segment_csr( |
| 685 | self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce |
| 686 | ), |
| 687 | coord=torch_scatter.segment_csr( |
| 688 | point.coord[indices], idx_ptr, reduce="mean" |
| 689 | ), |
| 690 | grid_coord=point.grid_coord[head_indices] >> pooling_depth, |
| 691 | serialized_code=code, |
| 692 | serialized_order=order, |
| 693 | serialized_inverse=inverse, |
| 694 | serialized_depth=point.serialized_depth - pooling_depth, |
| 695 | batch=point.batch[head_indices], |
| 696 | ) |