MCP · Copy · Index your code
hub / github.com/Pointcept/PointTransformerV3 / forward

Method forward

model.py:639–712  ·  view source on GitHub ↗
(self, point: Point)

Source from the content-addressed store, hash-verified

637 self.act = PointSequential(act_layer())
638
639 def forward(self, point: Point):
640 pooling_depth = (math.ceil(self.stride) - 1).bit_length()
641 if pooling_depth > point.serialized_depth:
642 pooling_depth = 0
643 assert {
644 "serialized_code",
645 "serialized_order",
646 "serialized_inverse",
647 "serialized_depth",
648 }.issubset(
649 point.keys()
650 ), "Run point.serialization() point cloud before SerializedPooling"
651
652 code = point.serialized_code >> pooling_depth * 3
653 code_, cluster, counts = torch.unique(
654 code[0],
655 sorted=True,
656 return_inverse=True,
657 return_counts=True,
658 )
659 # indices of point sorted by cluster, for torch_scatter.segment_csr
660 _, indices = torch.sort(cluster)
661 # index pointer for sorted point, for torch_scatter.segment_csr
662 idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
663 # head_indices of each cluster, for reduce attr e.g. code, batch
664 head_indices = indices[idx_ptr[:-1]]
665 # generate down code, order, inverse
666 code = code[:, head_indices]
667 order = torch.argsort(code)
668 inverse = torch.zeros_like(order).scatter_(
669 dim=1,
670 index=order,
671 src=torch.arange(0, code.shape[1], device=order.device).repeat(
672 code.shape[0], 1
673 ),
674 )
675
676 if self.shuffle_orders:
677 perm = torch.randperm(code.shape[0])
678 code = code[perm]
679 order = order[perm]
680 inverse = inverse[perm]
681
682 # collect information
683 point_dict = Dict(
684 feat=torch_scatter.segment_csr(
685 self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
686 ),
687 coord=torch_scatter.segment_csr(
688 point.coord[indices], idx_ptr, reduce="mean"
689 ),
690 grid_coord=point.grid_coord[head_indices] >> pooling_depth,
691 serialized_code=code,
692 serialized_order=order,
693 serialized_inverse=inverse,
694 serialized_depth=point.serialized_depth - pooling_depth,
695 batch=point.batch[head_indices],
696 )

Callers

Nothing calls this directly.

Calls 2

sparsify · Method · 0.95
Point · Class · 0.85

Tested by

No test coverage detected.