Args: indices (torch.Tensor): [..., N, C] tensor of spatial positions
(self, indices: torch.Tensor)
| 33 | return x_embed |
| 34 | |
def forward(self, indices: torch.Tensor) -> torch.Tensor:
    """Turn spatial positions into per-position phases, padded to ``head_dim // 2``.

    Args:
        indices (torch.Tensor): [..., N, C] tensor of spatial positions; the
            size of the last dimension must equal ``self.dim``.

    Returns:
        torch.Tensor: tensor of shape ``indices.shape[:-1] + (K,)`` holding the
        values produced by ``self._get_phases`` for every coordinate of each
        position. If K would be smaller than ``self.head_dim // 2``, the
        trailing axis is right-padded with ``torch.polar(1, 0)`` (i.e. ``1 + 0j``)
        entries up to exactly ``self.head_dim // 2``; otherwise it is returned
        unpadded.

    Raises:
        AssertionError: if ``indices.shape[-1] != self.dim`` (only while
            assertions are enabled).
    """
    assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}"

    # _get_phases is fed every coordinate as one flat 1-D tensor; its outputs are
    # then regrouped so each position owns a single trailing axis of phases.
    flat_phases = self._get_phases(indices.reshape(-1))
    phases = flat_phases.reshape(*indices.shape[:-1], -1)

    missing = self.head_dim // 2 - phases.shape[-1]
    if missing <= 0:
        # Already at (or beyond) the target width: nothing to pad.
        return phases

    # Unit magnitude with zero angle, i.e. 1 + 0j -- presumably an identity
    # rotation for the padded channels (TODO confirm against the consumer).
    pad_shape = (*phases.shape[:-1], missing)
    filler = torch.polar(
        torch.ones(pad_shape, device=phases.device),
        torch.zeros(pad_shape, device=phases.device),
    )
    return torch.cat((phases, filler), dim=-1)
nothing calls this directly
no test coverage detected