| 295 | |
| 296 | |
class RPE(torch.nn.Module):
    """Learned relative-position bias, looked up per axis and summed.

    The module owns one trainable table of shape ``(3 * rpe_num, num_heads)``.
    It is split into three consecutive slices of ``rpe_num`` rows, one for each
    coordinate axis. Relative offsets are clamped to ``[-pos_bnd, pos_bnd]``.
    Each clamped offset selects a row in its axis' slice. The three selected
    rows are added together to give one bias value per head.

    Args:
        patch_size: used only to derive the clamp bound
            ``pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)``.
        num_heads: number of bias values produced per position pair
            (the table's column count).
    """

    def __init__(self, patch_size, num_heads):
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        # Half-width of the clamp window; the expression is kept verbatim so
        # the float-to-int truncation is unchanged.
        bound = int((4 * patch_size) ** (1 / 3) * 2)
        self.pos_bnd = bound
        # Number of distinct clamped offsets per axis: -bound .. +bound.
        self.rpe_num = bound * 2 + 1
        table = torch.zeros(3 * self.rpe_num, num_heads)
        self.rpe_table = torch.nn.Parameter(table)
        torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)

    def forward(self, coord):
        """Return the summed per-axis bias for each relative offset.

        Args:
            coord: integer tensor of relative offsets. ``index_select`` needs
                an integer index. The ``sum`` over dim 3 and the 4-D
                ``permute`` imply a 4-D input whose last dim broadcasts
                against 3 (x, y, z), presumably ``(N, K, K, 3)`` — confirm
                against the caller.

        Returns:
            Tensor permuted from ``(N, K, K, H)`` to ``(N, H, K, K)``.
        """
        bound = self.pos_bnd
        # Start row of each axis' slice in the shared table: 0, rpe_num, 2*rpe_num.
        axis_start = torch.arange(3, device=coord.device) * self.rpe_num
        # Clamp, shift into [0, 2 * bound], then move into the axis' slice.
        rows = coord.clamp(-bound, bound) + bound + axis_start
        picked = self.rpe_table.index_select(0, rows.reshape(-1))
        # Restore the index layout with a trailing head dim, then add the axes.
        bias = picked.view(*rows.shape, -1).sum(dim=3)
        return bias.permute(0, 3, 1, 2)  # (N, K, K, H) -> (N, H, K, K)
| 317 | |
| 318 | |
| 319 | class SerializedAttention(PointModule): |