MCPcopy
hub / github.com/THUDM/LongWriter / apply_rotary_pos_emb

Function apply_rotary_pos_emb

train/patch/modeling_llama.py:205–229  ·  view source on GitHub ↗

Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedd

(q, k, cos, sin, position_ids=None, unsqueeze_dim=1)

Source from the content-addressed store, hash-verified

203
204
205def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
206 """Applies Rotary Position Embedding to the query and key tensors.
207
208 Args:
209 q (`torch.Tensor`): The query tensor.
210 k (`torch.Tensor`): The key tensor.
211 cos (`torch.Tensor`): The cosine part of the rotary embedding.
212 sin (`torch.Tensor`): The sine part of the rotary embedding.
213 position_ids (`torch.Tensor`, *optional*):
214 Deprecated and unused.
215 unsqueeze_dim (`int`, *optional*, defaults to 1):
216 The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
217 sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
218 that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
219 k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
220 cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
221 the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
222 Returns:
223 `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
224 """
225 cos = cos.unsqueeze(unsqueeze_dim)
226 sin = sin.unsqueeze(unsqueeze_dim)
227 q_embed = (q * cos) + (rotate_half(q) * sin)
228 k_embed = (k * cos) + (rotate_half(k) * sin)
229 return q_embed, k_embed
230
231
232class LlamaMLP(nn.Module):

Callers 4

forwardMethod · 0.70
forwardMethod · 0.70
forwardMethod · 0.70
forwardMethod · 0.70

Calls 1

rotate_halfFunction · 0.85

Tested by

no test coverage detected