| 318 | |
| 319 | class SerializedAttention(PointModule): |
def __init__(
    self,
    channels,
    num_heads,
    patch_size,
    qkv_bias=True,
    qk_scale=None,
    attn_drop=0.0,
    proj_drop=0.0,
    order_index=0,
    enable_rpe=False,
    enable_flash=True,
    upcast_attention=True,
    upcast_softmax=True,
):
    """Store the attention options and build the projection layers.

    Args:
        channels: Feature width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        patch_size: Kept as ``self.patch_size`` when flash attention is
            enabled, otherwise kept as ``self.patch_size_max`` (and
            ``self.patch_size`` starts at 0). Also handed to ``RPE`` when
            ``enable_rpe`` is truthy.
        qkv_bias: ``bias`` argument of the ``qkv`` linear layer.
        qk_scale: Attention scale. Any falsy value (``None``, ``0``) falls
            back to ``(channels // num_heads) ** -0.5``.
        attn_drop: Attention dropout probability. Stored as the raw number
            when flash attention is enabled, wrapped in ``torch.nn.Dropout``
            otherwise.
        proj_drop: Probability used for the ``proj_drop`` dropout module.
        order_index: Stored unchanged as ``self.order_index``.
        enable_rpe: Build an ``RPE`` module when truthy; must be ``False``
            together with ``enable_flash``.
        enable_flash: Select the flash-attention configuration.
        upcast_attention: Stored unchanged; must be ``False`` together with
            ``enable_flash``.
        upcast_softmax: Stored unchanged; must be ``False`` together with
            ``enable_flash``.

    Raises:
        AssertionError: If ``channels`` is not a multiple of ``num_heads``,
            if flash attention is combined with ``enable_rpe``,
            ``upcast_attention`` or ``upcast_softmax``, or if ``flash_attn``
            is ``None``.

    NOTE(review): the default arguments are mutually incompatible --
    ``enable_flash=True`` with ``upcast_attention=True`` and
    ``upcast_softmax=True`` fails the checks below, so callers have to pass
    these flags explicitly.
    """
    super().__init__()
    assert channels % num_heads == 0

    # Plain configuration values.
    self.channels = channels
    self.num_heads = num_heads
    self.scale = qk_scale if qk_scale else (channels // num_heads) ** -0.5
    self.order_index = order_index
    self.upcast_attention = upcast_attention
    self.upcast_softmax = upcast_softmax
    self.enable_rpe = enable_rpe
    self.enable_flash = enable_flash

    if not enable_flash:
        # Without flash attention no mask is used either; per the original
        # note, the effective patch size is chosen later as the smaller of
        # patch_size_max and the number of points.
        self.patch_size_max = patch_size
        self.patch_size = 0
        self.attn_drop = torch.nn.Dropout(attn_drop)
    else:
        # Options that cannot be combined with flash attention, checked in
        # a fixed order so the first offending flag is the one reported.
        incompatible_flags = (
            (
                enable_rpe,
                "Set enable_rpe to False when enable Flash Attention",
            ),
            (
                upcast_attention,
                "Set upcast_attention to False when enable Flash Attention",
            ),
            (
                upcast_softmax,
                "Set upcast_softmax to False when enable Flash Attention",
            ),
        )
        for flag, message in incompatible_flags:
            assert flag is False, message
        assert flash_attn is not None, "Make sure flash_attn is installed."
        self.patch_size = patch_size
        # Kept as the raw probability rather than a Dropout module.
        self.attn_drop = attn_drop

    # Layers are created in a fixed order: it determines submodule
    # registration order and the order of random weight initialisation.
    self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
    self.proj = torch.nn.Linear(channels, channels)
    self.proj_drop = torch.nn.Dropout(proj_drop)
    self.softmax = torch.nn.Softmax(dim=-1)
    if self.enable_rpe:
        self.rpe = RPE(patch_size, num_heads)
    else:
        self.rpe = None
| 371 | |
| 372 | @torch.no_grad() |
| 373 | def get_rel_pos(self, point, order): |