MCPcopy Index your code
hub / github.com/Standard-Intelligence/hertz-dev / ShapeRotator

Class ShapeRotator

transformer.py:49–134  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

47
48# Adapted from https://github.com/foundation-model-stack/foundation-model-stack
49class ShapeRotator:
50 def __init__(
51 self,
52 dim: int,
53 end: int,
54 theta: float = 10_000,
55 ):
56 super().__init__()
57 self.dim = dim
58 self.ratio = theta
59 self.cached_freqs: MutableMapping[int, MutableMapping[int, torch.Tensor]] = {}
60 self.max_seq_len_cached: MutableMapping[int, int] = {}
61 self.ntk_scaling = False
62 self.max_seq_len = end
63
64 def compute_freqs_cis(self, device, max_seq_len=None):
65 alpha = 1
66 dev_idx = device.index
67 max_seq_len = default(max_seq_len, self.max_seq_len)
68
69 if dev_idx not in self.cached_freqs:
70 self.cached_freqs[dev_idx] = {}
71 if dev_idx not in self.max_seq_len_cached:
72 self.max_seq_len_cached[dev_idx] = 0
73
74
75 if self.max_seq_len_cached[dev_idx] > 0:
76 return 1
77 max_seq_len = max(max_seq_len, self.max_seq_len)
78
79 if (
80 1 in self.cached_freqs[dev_idx]
81 and max_seq_len <= self.max_seq_len_cached[dev_idx]
82 ):
83 return 1
84
85 ratio = self.ratio
86 dim = self.dim
87
88 freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2, device=device).float() / dim))
89
90 t = torch.arange(max_seq_len, device=device, dtype=freqs.dtype)
91 freqs = torch.einsum("i,j->ij", t, freqs)
92 emb = torch.cat((freqs, freqs), dim=-1).to(device)
93
94 cos_to_cache = emb.cos()[None, :, None, :]
95 sin_to_cache = emb.sin()[None, :, None, :]
96
97 self.max_seq_len_cached[dev_idx] = max_seq_len
98
99 self.cached_freqs[dev_idx][alpha] = torch.stack(
100 [
101 cos_to_cache,
102 sin_to_cache,
103 ],
104 dim=-1,
105 )
106

Callers 2

__init__ · Method · 0.90
__init__ · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected