| 607 | |
| 608 | |
| 609 | class SerializedPooling(PointModule): |
def __init__(
    self,
    in_channels,
    out_channels,
    stride=2,
    norm_layer=None,
    act_layer=None,
    reduce="max",
    shuffle_orders=True,
    traceable=True,  # record parent and cluster
):
    """Configure a serialized pooling layer.

    Args:
        in_channels: Input feature width; input size of ``self.proj``.
        out_channels: Output feature width of ``self.proj`` and the size
            passed to ``norm_layer``.
        stride: Pooling stride. Must be an integer power of two
            (1, 2, 4, 8, ...).
        norm_layer: Optional callable; ``norm_layer(out_channels)`` is
            wrapped in ``PointSequential`` and stored as ``self.norm``.
            When ``None``, ``self.norm`` is ``None``.
        act_layer: Optional zero-argument callable; ``act_layer()`` is
            wrapped in ``PointSequential`` and stored as ``self.act``.
            When ``None``, ``self.act`` is ``None``.
        reduce: Reduction name; one of "sum", "mean", "min", "max".
        shuffle_orders: Stored as ``self.shuffle_orders``.
        traceable: Stored as ``self.traceable`` (record parent and cluster).

    Raises:
        AssertionError: If ``stride`` is not a power of two or ``reduce`` is
            not one of the supported reductions.
    """
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels

    # 2 ** (ceil(stride) - 1).bit_length() is the smallest power of two that
    # is >= ceil(stride), so equality holds only for stride in {1, 2, 4, 8, ...}.
    assert stride == 2 ** (math.ceil(stride) - 1).bit_length(), (
        f"stride must be a power of two (e.g. 2, 4, 8), got {stride!r}"
    )
    # TODO: add support to grid pool (any stride)
    self.stride = stride
    assert reduce in ["sum", "mean", "min", "max"], (
        f'reduce must be one of "sum", "mean", "min", "max", got {reduce!r}'
    )
    self.reduce = reduce
    self.shuffle_orders = shuffle_orders
    self.traceable = traceable

    self.proj = nn.Linear(in_channels, out_channels)
    # Always define norm/act (None when disabled) so later code can test
    # `self.norm is not None` / `self.act is not None` instead of raising
    # AttributeError when the optional layer was never assigned.
    self.norm = (
        PointSequential(norm_layer(out_channels))
        if norm_layer is not None
        else None
    )
    self.act = PointSequential(act_layer()) if act_layer is not None else None
| 638 | |
| 639 | def forward(self, point: Point): |
| 640 | pooling_depth = (math.ceil(self.stride) - 1).bit_length() |
| 641 | if pooling_depth > point.serialized_depth: |
| 642 | pooling_depth = 0 |
| 643 | assert { |
| 644 | "serialized_code", |
| 645 | "serialized_order", |
| 646 | "serialized_inverse", |
| 647 | "serialized_depth", |
| 648 | }.issubset( |
| 649 | point.keys() |
| 650 | ), "Run point.serialization() point cloud before SerializedPooling" |
| 651 | |
| 652 | code = point.serialized_code >> pooling_depth * 3 |
| 653 | code_, cluster, counts = torch.unique( |
| 654 | code[0], |
| 655 | sorted=True, |
| 656 | return_inverse=True, |
| 657 | return_counts=True, |
| 658 | ) |
| 659 | # indices of point sorted by cluster, for torch_scatter.segment_csr |
| 660 | _, indices = torch.sort(cluster) |
| 661 | # index pointer for sorted point, for torch_scatter.segment_csr |
| 662 | idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)]) |
| 663 | # head_indices of each cluster, for reduce attr e.g. code, batch |
| 664 | head_indices = indices[idx_ptr[:-1]] |
| 665 | # generate down code, order, inverse |
| 666 | code = code[:, head_indices] |