| 47 | |
| 48 | # Adapted from https://github.com/foundation-model-stack/foundation-model-stack |
| 49 | class ShapeRotator: |
| 50 | def __init__( |
| 51 | self, |
| 52 | dim: int, |
| 53 | end: int, |
| 54 | theta: float = 10_000, |
| 55 | ): |
| 56 | super().__init__() |
| 57 | self.dim = dim |
| 58 | self.ratio = theta |
| 59 | self.cached_freqs: MutableMapping[int, MutableMapping[int, torch.Tensor]] = {} |
| 60 | self.max_seq_len_cached: MutableMapping[int, int] = {} |
| 61 | self.ntk_scaling = False |
| 62 | self.max_seq_len = end |
| 63 | |
def compute_freqs_cis(self, device, max_seq_len=None):
    """Build and cache the stacked cos/sin table for ``device`` if needed.

    The table is stored in ``self.cached_freqs[device.index][1]`` and the
    length it was built for in ``self.max_seq_len_cached[device.index]``.

    Args:
        device: Object exposing ``.index``; also passed as ``device=`` to
            the tensor constructors below.
        max_seq_len: Requested number of positions. It is passed through
            ``default(...)`` together with ``self.max_seq_len`` (helper
            defined outside this excerpt; presumably a None-fallback,
            TODO confirm) and is never allowed below ``self.max_seq_len``.

    Returns:
        ``1`` on either early exit, i.e. when an existing table is reused.
        NOTE(review): the visible part of this method has no ``return``
        after a fresh table is stored; the method may continue past this
        excerpt. Confirm what callers receive on that path.
    """
    # Constant cache key for the table; the early exits return the same value.
    alpha = 1
    dev_idx = device.index
    max_seq_len = default(max_seq_len, self.max_seq_len)

    # First use of this device index: create its empty cache entries.
    if dev_idx not in self.cached_freqs:
        self.cached_freqs[dev_idx] = {}
    if dev_idx not in self.max_seq_len_cached:
        self.max_seq_len_cached[dev_idx] = 0


    # Once a table of positive length exists for this device it is reused
    # as-is, even if the requested max_seq_len is larger than the cached one.
    if self.max_seq_len_cached[dev_idx] > 0:
        return 1
    # Never build a table shorter than the configured length.
    max_seq_len = max(max_seq_len, self.max_seq_len)

    # NOTE(review): the cached length is <= 0 whenever this point is
    # reached, so this check can only succeed for a non-positive
    # max_seq_len; it looks like a leftover from a more general cache
    # policy. Confirm before relying on it.
    if (
        1 in self.cached_freqs[dev_idx]
        and max_seq_len <= self.max_seq_len_cached[dev_idx]
    ):
        return 1

    ratio = self.ratio
    dim = self.dim

    # Inverse frequencies 1 / ratio**(k / dim) for k = 0, 2, 4, ... < dim.
    freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2, device=device).float() / dim))

    # Positions 0 .. max_seq_len - 1, in the same dtype as freqs.
    t = torch.arange(max_seq_len, device=device, dtype=freqs.dtype)
    # Outer product: one row per position, one column per frequency.
    freqs = torch.einsum("i,j->ij", t, freqs)
    # Repeat the angles along the last axis (width becomes dim when dim is even).
    emb = torch.cat((freqs, freqs), dim=-1).to(device)

    # Insert singleton axes at positions 0 and 2: (1, max_seq_len, 1, width).
    cos_to_cache = emb.cos()[None, :, None, :]
    sin_to_cache = emb.sin()[None, :, None, :]

    self.max_seq_len_cached[dev_idx] = max_seq_len

    # A new trailing axis of size 2 holds cos at index 0 and sin at index 1.
    self.cached_freqs[dev_idx][alpha] = torch.stack(
        [
            cos_to_cache,
            sin_to_cache,
        ],
        dim=-1,
    )
| 106 | |