MCP · copy
hub / github.com/Pointcept/PointTransformerV3 / SerializedPooling

Class SerializedPooling

model.py:609–712  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

607
608
609class SerializedPooling(PointModule):
def __init__(
    self,
    in_channels,
    out_channels,
    stride=2,
    norm_layer=None,
    act_layer=None,
    reduce="max",
    shuffle_orders=True,
    traceable=True,  # record parent and cluster
):
    """Configure a serialization-based pooling layer.

    Args:
        in_channels: Input feature width; the input size of ``self.proj``.
        out_channels: Output feature width of ``self.proj``; also passed to
            ``norm_layer``.
        stride: Pooling stride. Must satisfy
            ``stride == 2 ** (ceil(stride) - 1).bit_length()``, i.e. a power
            of two (1, 2, 4, 8, ...). Other strides are not supported yet.
        norm_layer: Optional callable invoked as ``norm_layer(out_channels)``;
            its result is wrapped in ``PointSequential``. ``None`` disables it.
        act_layer: Optional zero-argument callable; its result is wrapped in
            ``PointSequential``. ``None`` disables it.
        reduce: Name of the reduction, one of "sum", "mean", "min", "max".
        shuffle_orders: Stored as ``self.shuffle_orders``.
        traceable: Stored as ``self.traceable`` (per the inline note, meant to
            control recording of parent and cluster information).

    Raises:
        ValueError: If ``stride`` is not a power of two, or ``reduce`` is not
            one of the supported reduction names.
    """
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels

    # forward() derives `pooling_depth` from the stride with this same
    # expression, so only power-of-two strides map onto a whole depth.
    # Raised explicitly (not `assert`) so the check survives `python -O`.
    if stride != 2 ** (math.ceil(stride) - 1).bit_length():
        raise ValueError(
            f"stride must be a power of two (e.g. 2, 4, 8), got {stride!r}"
        )
    # TODO: add support to grid pool (any stride)
    self.stride = stride

    if reduce not in ("sum", "mean", "min", "max"):
        raise ValueError(
            'reduce must be one of "sum", "mean", "min", "max", '
            f"got {reduce!r}"
        )
    self.reduce = reduce
    self.shuffle_orders = shuffle_orders
    self.traceable = traceable

    self.proj = nn.Linear(in_channels, out_channels)
    # Always define `norm` / `act` (None when not requested) so the instance
    # has the same attribute set regardless of constructor arguments and
    # callers can test `is None` instead of hitting AttributeError.
    self.norm = (
        PointSequential(norm_layer(out_channels)) if norm_layer is not None else None
    )
    self.act = PointSequential(act_layer()) if act_layer is not None else None
638
639 def forward(self, point: Point):
640 pooling_depth = (math.ceil(self.stride) - 1).bit_length()
641 if pooling_depth > point.serialized_depth:
642 pooling_depth = 0
643 assert {
644 "serialized_code",
645 "serialized_order",
646 "serialized_inverse",
647 "serialized_depth",
648 }.issubset(
649 point.keys()
650 ), "Run point.serialization() point cloud before SerializedPooling"
651
652 code = point.serialized_code >> pooling_depth * 3
653 code_, cluster, counts = torch.unique(
654 code[0],
655 sorted=True,
656 return_inverse=True,
657 return_counts=True,
658 )
659 # indices of point sorted by cluster, for torch_scatter.segment_csr
660 _, indices = torch.sort(cluster)
661 # index pointer for sorted point, for torch_scatter.segment_csr
662 idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
663 # head_indices of each cluster, for reduce attr e.g. code, batch
664 head_indices = indices[idx_ptr[:-1]]
665 # generate down code, order, inverse
666 code = code[:, head_indices]

Callers 1

__init__ · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected