MCP · Copy · Index your code
hub / github.com/Pointcept/PointTransformerV3 / RPE

Class RPE

model.py:297–316  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

295
296
class RPE(torch.nn.Module):
    """Learned relative-position bias looked up per axis and summed.

    A single parameter table holds three stacked sub-tables of ``rpe_num``
    rows each, one per coordinate axis (x, y, z). Each row has
    ``num_heads`` columns. ``forward`` clamps every relative offset into
    ``[-pos_bnd, pos_bnd]``, fetches one row per axis and adds the three
    rows together. The result is returned with the head axis moved to
    position 1.
    """

    def __init__(self, patch_size, num_heads):
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads

        # Half-width of the clamped offset range: 2 * cbrt(4 * patch_size),
        # truncated to int. NOTE(review): the float cube root can land just
        # below an exact integer (e.g. patch_size=16 -> 64 ** (1/3) is
        # 3.9999... on IEEE doubles, giving 7 instead of 8). This is kept
        # as-is on purpose: changing it resizes rpe_table and breaks
        # previously saved state_dicts.
        bound = int((4 * patch_size) ** (1 / 3) * 2)
        self.pos_bnd = bound
        # Number of distinct clamped offsets per axis: -bound .. +bound.
        self.rpe_num = 2 * bound + 1

        # One block of rpe_num rows for each of the three axes.
        table = torch.zeros(3 * self.rpe_num, num_heads)
        self.rpe_table = torch.nn.Parameter(table)
        torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)

    def forward(self, coord):
        """Return the summed per-axis bias for each relative offset.

        Args:
            coord: integer tensor whose last dimension holds (x, y, z)
                relative offsets. It must be an integer dtype because the
                result is used as an ``index_select`` index. The permute
                below requires 4 dims; presumably (N, K, K, 3), matching
                the shape comment carried over from the original code.
                TODO confirm against the caller.

        Returns:
            Tensor of shape (N, H, K, K), where H == num_heads.
        """
        bound, bins = self.pos_bnd, self.rpe_num
        # Row offset of each axis's sub-table inside rpe_table.
        axis_offset = torch.arange(3, device=coord.device) * bins
        # Clamp into [-bound, bound], then shift to a non-negative row index.
        shifted = coord.clamp(-bound, bound) + bound
        idx = shifted + axis_offset

        gathered = self.rpe_table.index_select(0, idx.reshape(-1))
        per_axis = gathered.view(*idx.shape, -1)  # (..., 3, H)
        bias = per_axis.sum(dim=3)  # add the x, y and z contributions
        # (N, K, K, H) -> (N, H, K, K)
        return bias.permute(0, 3, 1, 2)
317
318
319class SerializedAttention(PointModule):

Callers 1

__init__ · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected