(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
)
| 28 | |
| 29 | |
| 30 | def forward( |
| 31 | self, |
| 32 | hidden_states: torch.Tensor, |
| 33 | attention_mask: Optional[torch.Tensor] = None, |
| 34 | position_ids: Optional[torch.LongTensor] = None, |
| 35 | past_key_value: Optional[Tuple[torch.Tensor]] = None, |
| 36 | output_attentions: bool = False, |
| 37 | use_cache: bool = False, |
| 38 | padding_mask: Optional[torch.LongTensor] = None, |
| 39 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
| 40 | bsz, q_len, _ = hidden_states.size() |
| 41 | |
| 42 | query_states = ( |
| 43 | self.q_proj(hidden_states) |
| 44 | .view(bsz, q_len, self.num_heads, self.head_dim) |
| 45 | .transpose(1, 2) |
| 46 | ) |
| 47 | key_states = ( |
| 48 | self.k_proj(hidden_states) |
| 49 | .view(bsz, q_len, self.num_heads, self.head_dim) |
| 50 | .transpose(1, 2) |
| 51 | ) |
| 52 | value_states = ( |
| 53 | self.v_proj(hidden_states) |
| 54 | .view(bsz, q_len, self.num_heads, self.head_dim) |
| 55 | .transpose(1, 2) |
| 56 | ) |
| 57 | |
| 58 | kv_seq_len = key_states.shape[-2] |
| 59 | if past_key_value is not None: |
| 60 | kv_seq_len += past_key_value[0].shape[-2] |
| 61 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) |
| 62 | query_states, key_states = apply_rotary_pos_emb( |
| 63 | query_states, key_states, cos, sin, position_ids |
| 64 | ) |
| 65 | # [bsz, nh, t, hd] |
| 66 | |
| 67 | if past_key_value is not None: |
| 68 | # reuse k, v, self_attention |
| 69 | key_states = torch.cat([past_key_value[0], key_states], dim=2) |
| 70 | value_states = torch.cat([past_key_value[1], value_states], dim=2) |
| 71 | |
| 72 | past_key_value = (key_states, value_states) if use_cache else None |
| 73 | |
| 74 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( |
| 75 | self.head_dim |
| 76 | ) |
| 77 | |
| 78 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): |
| 79 | raise ValueError( |
| 80 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" |
| 81 | f" {attn_weights.size()}" |
| 82 | ) |
| 83 | |
| 84 | if attention_mask is not None: |
| 85 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): |
| 86 | raise ValueError( |
| 87 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" |
nothing calls this directly
no test coverage detected
searching dependent graphs…