(
self,
channels,
num_heads,
patch_size=48,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
act_layer=nn.GELU,
pre_norm=True,
order_index=0,
cpe_indice_key=None,
enable_rpe=False,
enable_flash=True,
upcast_attention=True,
upcast_softmax=True,
)
| 518 | |
| 519 | class Block(PointModule): |
| 520 | def __init__( |
| 521 | self, |
| 522 | channels, |
| 523 | num_heads, |
| 524 | patch_size=48, |
| 525 | mlp_ratio=4.0, |
| 526 | qkv_bias=True, |
| 527 | qk_scale=None, |
| 528 | attn_drop=0.0, |
| 529 | proj_drop=0.0, |
| 530 | drop_path=0.0, |
| 531 | norm_layer=nn.LayerNorm, |
| 532 | act_layer=nn.GELU, |
| 533 | pre_norm=True, |
| 534 | order_index=0, |
| 535 | cpe_indice_key=None, |
| 536 | enable_rpe=False, |
| 537 | enable_flash=True, |
| 538 | upcast_attention=True, |
| 539 | upcast_softmax=True, |
| 540 | ): |
| 541 | super().__init__() |
| 542 | self.channels = channels |
| 543 | self.pre_norm = pre_norm |
| 544 | |
| 545 | self.cpe = PointSequential( |
| 546 | spconv.SubMConv3d( |
| 547 | channels, |
| 548 | channels, |
| 549 | kernel_size=3, |
| 550 | bias=True, |
| 551 | indice_key=cpe_indice_key, |
| 552 | ), |
| 553 | nn.Linear(channels, channels), |
| 554 | norm_layer(channels), |
| 555 | ) |
| 556 | |
| 557 | self.norm1 = PointSequential(norm_layer(channels)) |
| 558 | self.attn = SerializedAttention( |
| 559 | channels=channels, |
| 560 | patch_size=patch_size, |
| 561 | num_heads=num_heads, |
| 562 | qkv_bias=qkv_bias, |
| 563 | qk_scale=qk_scale, |
| 564 | attn_drop=attn_drop, |
| 565 | proj_drop=proj_drop, |
| 566 | order_index=order_index, |
| 567 | enable_rpe=enable_rpe, |
| 568 | enable_flash=enable_flash, |
| 569 | upcast_attention=upcast_attention, |
| 570 | upcast_softmax=upcast_softmax, |
| 571 | ) |
| 572 | self.norm2 = PointSequential(norm_layer(channels)) |
| 573 | self.mlp = PointSequential( |
| 574 | MLP( |
| 575 | in_channels=channels, |
| 576 | hidden_channels=int(channels * mlp_ratio), |
| 577 | out_channels=channels, |
nothing calls this directly
no test coverage detected