(self, t)
| 133 | print('======== shape of rope freq', self.freqs_cos.shape, '========') |
| 134 | |
def forward(self, t):
    """Apply the precomputed rotary embedding to ``t``.

    Each rotated slice is computed as
    ``x * self.freqs_cos + rotate_half(x) * self.freqs_sin``.

    When ``t.shape[1]`` is odd, the first position along dim 1 is left
    untouched. Only positions ``1:`` are rotated, and the two parts are
    concatenated back together along dim 1. When ``t.shape[1]`` is even, the
    whole tensor is rotated.

    NOTE(review): the odd-length branch presumably skips a leading class
    token, and ``t`` is indexed as a 3-D tensor (``t[:, 1:, :]``). The exact
    axis meaning (e.g. batch, tokens, channels) is assumed; confirm against
    callers. The ``freqs_*`` tensors must broadcast against the rotated slice.
    """
    cos, sin = self.freqs_cos, self.freqs_sin

    # Even length along dim 1: no pass-through position, rotate everything.
    if t.shape[1] % 2 == 0:
        return t * cos + rotate_half(t) * sin

    # Odd length: keep position 0 as-is and rotate the remainder.
    head = t[:, :1, :]
    body = t[:, 1:, :]
    rotated_body = body * cos + rotate_half(body) * sin
    return torch.cat((head, rotated_body), dim=1)
nothing calls this directly
no test coverage detected