Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedd
(q, k, cos, sin, position_ids=None, unsqueeze_dim=1)
| 203 | |
| 204 | |
| 205 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
| 206 | """Applies Rotary Position Embedding to the query and key tensors. |
| 207 | |
| 208 | Args: |
| 209 | q (`torch.Tensor`): The query tensor. |
| 210 | k (`torch.Tensor`): The key tensor. |
| 211 | cos (`torch.Tensor`): The cosine part of the rotary embedding. |
| 212 | sin (`torch.Tensor`): The sine part of the rotary embedding. |
| 213 | position_ids (`torch.Tensor`, *optional*): |
| 214 | Deprecated and unused. |
| 215 | unsqueeze_dim (`int`, *optional*, defaults to 1): |
| 216 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and |
| 217 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note |
| 218 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and |
| 219 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes |
| 220 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have |
| 221 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. |
| 222 | Returns: |
| 223 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. |
| 224 | """ |
| 225 | cos = cos.unsqueeze(unsqueeze_dim) |
| 226 | sin = sin.unsqueeze(unsqueeze_dim) |
| 227 | q_embed = (q * cos) + (rotate_half(q) * sin) |
| 228 | k_embed = (k * cos) + (rotate_half(k) * sin) |
| 229 | return q_embed, k_embed |
| 230 | |
| 231 | |
| 232 | class LlamaMLP(nn.Module): |
no test coverage detected