(self, patch_size, num_heads)
| 296 | |
| 297 | class RPE(torch.nn.Module): |
| 298 | def __init__(self, patch_size, num_heads): |
|  | """Create a learnable bias table with num_heads columns, sized from patch_size (presumably relative positional encoding, per the class name RPE).""" |
| 299 | super().__init__() |
| 300 | self.patch_size = patch_size |
| 301 | self.num_heads = num_heads |
| 302 | self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)  # int(2 * cbrt(4 * patch_size)), truncated toward zero. NOTE: float cube roots of perfect cubes can land just below the integer (64 ** (1 / 3) == 3.9999999999999996), so e.g. patch_size=16 yields 7, not 8; changing this formula changes rpe_table's shape |
| 303 | self.rpe_num = 2 * self.pos_bnd + 1  # 2 * pos_bnd + 1 slots, presumably one per integer offset in [-pos_bnd, pos_bnd]; confirm against forward |
| 304 | self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))  # shape (3 * rpe_num, num_heads); the factor 3 presumably gives one block of rpe_num rows per spatial axis (TODO confirm in forward) |
| 305 | torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)  # replaces the zeros above in place with truncated-normal values, std=0.02 |
| 306 | |
| 307 | def forward(self, coord): |
| 308 | idx = ( |