(self, point)
| 438 | return point[pad_key], point[unpad_key], point[cu_seqlens_key] |
| 439 | |
def forward(self, point):
    """Run patch-wise self-attention over the serialized points.

    Reads ``point.feat``, ``point.offset`` (non-flash path only),
    ``point.serialized_order`` and ``point.serialized_inverse``; writes the
    attended and projected features back to ``point.feat`` and returns the
    same ``point`` object.

    Side effect: when ``self.enable_flash`` is false, ``self.patch_size`` is
    overwritten on every call with the smaller of
    ``offset2bincount(point.offset).min()`` and ``self.patch_size_max``.
    """
    if not self.enable_flash:
        # Clamp the patch size for this call; must happen before
        # get_padding_and_inverse() and before patch_size is read below.
        smallest_count = offset2bincount(point.offset).min().tolist()
        self.patch_size = min(smallest_count, self.patch_size_max)

    num_heads = self.num_heads
    patch_size = self.patch_size
    channels = self.channels

    pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)

    # Index maps into / out of the padded, serialized ordering.
    order = point.serialized_order[self.order_index][pad]
    inverse = unpad[point.serialized_inverse[self.order_index]]

    # Project to packed qkv, then gather rows in padded serialized order.
    qkv = self.qkv(point.feat)[order]
    head_dim = channels // num_heads

    if self.enable_flash:
        # Variable-length flash attention on the half-precision packed qkv.
        dropout_p = self.attn_drop if self.training else 0
        packed = qkv.half().reshape(-1, 3, num_heads, head_dim)
        feat = flash_attn.flash_attn_varlen_qkvpacked_func(
            packed,
            cu_seqlens,
            max_seqlen=patch_size,
            dropout_p=dropout_p,
            softmax_scale=self.scale,
        ).reshape(-1, channels)
        feat = feat.to(qkv.dtype)
    else:
        # (N', K, 3, H, C') => (3, N', H, K, C'), then split into q / k / v.
        packed = qkv.reshape(-1, patch_size, 3, num_heads, head_dim)
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(dim=0)

        if self.upcast_attention:
            query = query.float()
            key = key.float()

        # Scaled dot-product scores: (N', H, K, K).
        scores = (query * self.scale) @ key.transpose(-2, -1)
        if self.enable_rpe:
            rel_pos = self.get_rel_pos(point, order)
            scores = scores + self.rpe(rel_pos)
        if self.upcast_softmax:
            scores = scores.float()

        weights = self.softmax(scores)
        # Cast back to the qkv dtype after dropout so the product with
        # ``value`` runs in the projection's dtype.
        weights = self.attn_drop(weights).to(qkv.dtype)
        feat = (weights @ value).transpose(1, 2).reshape(-1, channels)

    # Drop padding and restore the original point order.
    feat = feat[inverse]

    # Output projection followed by its dropout.
    point.feat = self.proj_drop(self.proj(feat))
    return point
| 491 | |
| 492 | |
| 493 | class MLP(nn.Module): |
nothing calls this directly
no test coverage detected