(self, indices: torch.Tensor)
| 20 | self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) |
| 21 | |
| 22 | def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: |
| 23 | self.freqs = self.freqs.to(indices.device) |
| 24 | phases = torch.outer(indices, self.freqs) |
| 25 | phases = torch.polar(torch.ones_like(phases), phases) |
| 26 | return phases |
| 27 | |
| 28 | @staticmethod |
| 29 | def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: |