(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
)
| 300 | raise ValueError(f"Unknown RoPE scaling type {scaling_type}") |
| 301 | |
| 302 | def forward( |
| 303 | self, |
| 304 | hidden_states: torch.Tensor, |
| 305 | attention_mask: Optional[torch.Tensor] = None, |
| 306 | position_ids: Optional[torch.LongTensor] = None, |
| 307 | past_key_value: Optional[Cache] = None, |
| 308 | output_attentions: bool = False, |
| 309 | use_cache: bool = False, |
| 310 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
| 311 | bsz, q_len, _ = hidden_states.size() |
| 312 | |
| 313 | query_states = self.q_proj(hidden_states) |
| 314 | key_states = self.k_proj(hidden_states) |
| 315 | value_states = self.v_proj(hidden_states) |
| 316 | |
| 317 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
| 318 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
| 319 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
| 320 | |
| 321 | if self.qk_layernorm: |
| 322 | query_states = self.q_layernorm(query_states) |
| 323 | key_states = self.k_layernorm(key_states) |
| 324 | |
| 325 | kv_seq_len = key_states.shape[-2] |
| 326 | if past_key_value is not None: |
| 327 | if self.layer_idx is None: |
| 328 | raise ValueError( |
| 329 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " |
| 330 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " |
| 331 | "with a layer index." |
| 332 | ) |
| 333 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) |
| 334 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) |
| 335 | |
| 336 | # Partial rotary embedding |
| 337 | query_rot, query_pass = ( |
| 338 | query_states[..., : self.rotary_emb.dim], |
| 339 | query_states[..., self.rotary_emb.dim :], |
| 340 | ) |
| 341 | key_rot, key_pass = ( |
| 342 | key_states[..., : self.rotary_emb.dim], |
| 343 | key_states[..., self.rotary_emb.dim :], |
| 344 | ) |
| 345 | # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] |
| 346 | query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) |
| 347 | |
| 348 | # [batch_size, seq_length, num_heads, head_dim] |
| 349 | query_states = torch.cat((query_rot, query_pass), dim=-1) |
| 350 | key_states = torch.cat((key_rot, key_pass), dim=-1) |
| 351 | |
| 352 | if past_key_value is not None: |
| 353 | cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} |
| 354 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) |
| 355 | |
| 356 | key_states = repeat_kv(key_states, self.num_key_value_groups) |
| 357 | value_states = repeat_kv(value_states, self.num_key_value_groups) |
| 358 | |
| 359 | # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow |
Nothing calls this method directly.
No test coverage was detected.