(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
)
| 314 | self.rotary_emb = LlamaRotaryEmbedding(config=self.config) |
| 315 | |
| 316 | def forward( |
| 317 | self, |
| 318 | hidden_states: torch.Tensor, |
| 319 | attention_mask: Optional[torch.Tensor] = None, |
| 320 | position_ids: Optional[torch.LongTensor] = None, |
| 321 | past_key_value: Optional[Cache] = None, |
| 322 | output_attentions: bool = False, |
| 323 | use_cache: bool = False, |
| 324 | cache_position: Optional[torch.LongTensor] = None, |
| 325 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 |
| 326 | **kwargs, |
| 327 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
| 328 | bsz, q_len, _ = hidden_states.size() |
| 329 | |
| 330 | if self.config.pretraining_tp > 1: |
| 331 | key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp |
| 332 | query_slices = self.q_proj.weight.split( |
| 333 | (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 |
| 334 | ) |
| 335 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) |
| 336 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) |
| 337 | |
| 338 | query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] |
| 339 | query_states = torch.cat(query_states, dim=-1) |
| 340 | |
| 341 | key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] |
| 342 | key_states = torch.cat(key_states, dim=-1) |
| 343 | |
| 344 | value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] |
| 345 | value_states = torch.cat(value_states, dim=-1) |
| 346 | |
| 347 | else: |
| 348 | query_states = self.q_proj(hidden_states) |
| 349 | key_states = self.k_proj(hidden_states) |
| 350 | value_states = self.v_proj(hidden_states) |
| 351 | |
| 352 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
| 353 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
| 354 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
| 355 | |
| 356 | if position_embeddings is None: |
| 357 | logger.warning_once( |
| 358 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally " |
| 359 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " |
| 360 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " |
| 361 | "removed and `position_embeddings` will be mandatory." |
| 362 | ) |
| 363 | cos, sin = self.rotary_emb(value_states, position_ids) |
| 364 | else: |
| 365 | cos, sin = position_embeddings |
| 366 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
| 367 | |
| 368 | if past_key_value is not None: |
| 369 | # sin and cos are specific to RoPE models; cache_position needed for the static cache |
| 370 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
| 371 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) |
| 372 | |
| 373 | key_states = repeat_kv(key_states, self.num_key_value_groups) |
nothing calls this directly
no test coverage detected