Rotates half the hidden dims of the input.
(x)
| 51 | |
| 52 | |
def rotate_half(x):
    """Swap the two halves of the last axis of ``x``, negating the moved half.

    The last axis is split at ``x.shape[-1] // 2`` into a leading part ``a``
    and a trailing part ``b``. The result is ``concat(-b, a)`` along that same
    axis, so the output has the same shape as ``x``. For an odd-sized last
    axis, ``b`` is one element longer than ``a``.

    Args:
        x: A tensor; only its last dimension is rearranged.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    split_at = x.shape[-1] // 2
    leading = x[..., :split_at]
    trailing = x[..., split_at:]
    # The trailing half moves to the front with its sign flipped.
    return torch.cat([trailing.neg(), leading], dim=-1)
| 58 | |
| 59 | |
| 60 | def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2): |
no outgoing calls
no test coverage detected