(
self,
in_channels=6,
order=("z", "z-trans", "hilbert", "hilbert-trans"),
stride=(2, 2, 2, 2),
enc_depths=(2, 2, 2, 6, 2),
enc_channels=(32, 64, 128, 256, 512),
enc_num_head=(2, 4, 8, 16, 32),
enc_patch_size=(1024, 1024, 1024, 1024, 1024),
dec_depths=(2, 2, 2, 2),
dec_channels=(64, 64, 128, 256),
dec_num_head=(4, 4, 8, 16),
dec_patch_size=(1024, 1024, 1024, 1024),
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
drop_path=0.3,
pre_norm=True,
shuffle_orders=True,
enable_rpe=False,
enable_flash=True,
upcast_attention=False,
upcast_softmax=False,
cls_mode=False,
pdnorm_bn=False,
pdnorm_ln=False,
pdnorm_decouple=True,
pdnorm_adaptive=False,
pdnorm_affine=True,
pdnorm_conditions=("ScanNet", "S3DIS", "Structured3D"),
)
| 785 | |
| 786 | class PointTransformerV3(PointModule): |
| 787 | def __init__( |
| 788 | self, |
| 789 | in_channels=6, |
| 790 | order=("z", "z-trans", "hilbert", "hilbert-trans"), |
| 791 | stride=(2, 2, 2, 2), |
| 792 | enc_depths=(2, 2, 2, 6, 2), |
| 793 | enc_channels=(32, 64, 128, 256, 512), |
| 794 | enc_num_head=(2, 4, 8, 16, 32), |
| 795 | enc_patch_size=(1024, 1024, 1024, 1024, 1024), |
| 796 | dec_depths=(2, 2, 2, 2), |
| 797 | dec_channels=(64, 64, 128, 256), |
| 798 | dec_num_head=(4, 4, 8, 16), |
| 799 | dec_patch_size=(1024, 1024, 1024, 1024), |
| 800 | mlp_ratio=4, |
| 801 | qkv_bias=True, |
| 802 | qk_scale=None, |
| 803 | attn_drop=0.0, |
| 804 | proj_drop=0.0, |
| 805 | drop_path=0.3, |
| 806 | pre_norm=True, |
| 807 | shuffle_orders=True, |
| 808 | enable_rpe=False, |
| 809 | enable_flash=True, |
| 810 | upcast_attention=False, |
| 811 | upcast_softmax=False, |
| 812 | cls_mode=False, |
| 813 | pdnorm_bn=False, |
| 814 | pdnorm_ln=False, |
| 815 | pdnorm_decouple=True, |
| 816 | pdnorm_adaptive=False, |
| 817 | pdnorm_affine=True, |
| 818 | pdnorm_conditions=("ScanNet", "S3DIS", "Structured3D"), |
| 819 | ): |
| 820 | super().__init__() |
| 821 | self.num_stages = len(enc_depths) |
| 822 | self.order = [order] if isinstance(order, str) else order |
| 823 | self.cls_mode = cls_mode |
| 824 | self.shuffle_orders = shuffle_orders |
| 825 | |
| 826 | assert self.num_stages == len(stride) + 1 |
| 827 | assert self.num_stages == len(enc_depths) |
| 828 | assert self.num_stages == len(enc_channels) |
| 829 | assert self.num_stages == len(enc_num_head) |
| 830 | assert self.num_stages == len(enc_patch_size) |
| 831 | assert self.cls_mode or self.num_stages == len(dec_depths) + 1 |
| 832 | assert self.cls_mode or self.num_stages == len(dec_channels) + 1 |
| 833 | assert self.cls_mode or self.num_stages == len(dec_num_head) + 1 |
| 834 | assert self.cls_mode or self.num_stages == len(dec_patch_size) + 1 |
| 835 | |
| 836 | # norm layers |
| 837 | if pdnorm_bn: |
| 838 | bn_layer = partial( |
| 839 | PDNorm, |
| 840 | norm_layer=partial( |
| 841 | nn.BatchNorm1d, eps=1e-3, momentum=0.01, affine=pdnorm_affine |
| 842 | ), |
| 843 | conditions=pdnorm_conditions, |
| 844 | decouple=pdnorm_decouple, |
No direct callers of this definition were found.
No test coverage was detected for it.