(q, k, cos, sin, position_ids)
| 18 | |
| 19 | |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary position embeddings to the query and key tensors.

    The ``cos``/``sin`` tables are gathered along their dim 2 at
    ``position_ids``. The gathered tables are then combined with ``q`` and
    ``k`` as ``x * cos + rotate_half(x) * sin``.

    Args:
        q: Query tensor. It must broadcast against the gathered tables.
            Presumably (bs, num_heads, seq_len, head_dim) — TODO confirm.
        k: Key tensor, with the same broadcasting requirement as ``q``.
        cos: 4-D cosine table, indexed along dim 2. Presumably
            (1, 1, max_seq_len, head_dim) — TODO confirm against caller.
        sin: 4-D sine table, laid out like ``cos``.
        position_ids: 2-D integer tensor of positions to look up, shaped
            (bs, seq_len).

    Returns:
        Tuple ``(q_embed, k_embed)``.
    """
    batch = position_ids.shape[0]

    def _tile_batch(table):
        # expand() is a zero-copy view, unlike repeat(), which materialises a
        # full copy of the table on every call. expand() is only legal when the
        # batch dim is 1 or already `batch`. In both of those cases,
        # repeat(batch, 1, 1, 1) followed by gather reads exactly the same rows.
        # Any other batch size keeps the original copying behaviour.
        if table.shape[0] in (1, batch):
            return table.expand(batch, -1, -1, -1)
        return table.repeat(batch, 1, 1, 1)

    # [bs, 1, seq_len, 1] -> [bs, cos.shape[1], seq_len, cos.shape[3]].
    # Only singleton dims are enlarged, so this view holds the same values the
    # original repeat(1, cos.shape[1], 1, cos.shape[3]) produced.
    gather_indices = position_ids[:, None, :, None].expand(
        -1, cos.shape[1], -1, cos.shape[3]
    )
    cos = torch.gather(_tile_batch(cos), 2, gather_indices)
    sin = torch.gather(_tile_batch(sin), 2, gather_indices)
    # rotate_half is defined elsewhere in the file.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
| 28 | |
| 29 | |
| 30 | def forward( |
no test coverage detected
searching dependent graphs…