(
self,
hidden_states: Optional[jt.Var],
attention_mask: Optional[jt.Var] = None,
layer_past: Optional[Tuple[jt.Var]] = None,
head_mask: Optional[jt.Var] = None,
use_cache: Optional[bool] = False,
)
| 102 | return attn_output, attn_weights |
| 103 | |
| 104 | def execute( |
| 105 | self, |
| 106 | hidden_states: Optional[jt.Var], |
| 107 | attention_mask: Optional[jt.Var] = None, |
| 108 | layer_past: Optional[Tuple[jt.Var]] = None, |
| 109 | head_mask: Optional[jt.Var] = None, |
| 110 | use_cache: Optional[bool] = False, |
| 111 | ) -> Union[ |
| 112 | Tuple[jt.Var, Tuple[jt.Var]], |
| 113 | Optional[Tuple[jt.Var, Tuple[jt.Var], Tuple[jt.Var, ...]]], |
| 114 | ]: |
| 115 | qkv = self.qkv_proj(hidden_states) |
| 116 | mp_num = 4 |
| 117 | qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1)) |
| 118 | |
| 119 | local_dim = self.head_dim * self.num_attention_heads // mp_num |
| 120 | query, value, key = jt.split(qkv_split, local_dim, dim=-1) |
| 121 | query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num) |
| 122 | key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num) |
| 123 | |
| 124 | value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num) |
| 125 | value = value.permute(0, 2, 1, 3) |
| 126 | |
| 127 | seq_len = key.shape[1] |
| 128 | offset = 0 |
| 129 | |
| 130 | if layer_past is not None: |
| 131 | offset = layer_past[0].shape[-2] |
| 132 | seq_len += offset |
| 133 | |
| 134 | if self.rotary_dim is not None: |
| 135 | k_rot = key[:, :, :, : self.rotary_dim] |
| 136 | k_pass = key[:, :, :, self.rotary_dim :] |
| 137 | |
| 138 | q_rot = query[:, :, :, : self.rotary_dim] |
| 139 | q_pass = query[:, :, :, self.rotary_dim :] |
| 140 | |
| 141 | sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) |
| 142 | k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) |
| 143 | q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) |
| 144 | |
| 145 | key = jt.cat([k_rot, k_pass], dim=-1) |
| 146 | query = jt.cat([q_rot, q_pass], dim=-1) |
| 147 | else: |
| 148 | sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) |
| 149 | key = apply_rotary_pos_emb(key, sincos, offset=offset) |
| 150 | query = apply_rotary_pos_emb(query, sincos, offset=offset) |
| 151 | |
| 152 | key = key.permute(0, 2, 1, 3) |
| 153 | query = query.permute(0, 2, 1, 3) |
| 154 | |
| 155 | if layer_past is not None: |
| 156 | past_key = layer_past[0] |
| 157 | past_value = layer_past[1] |
| 158 | key = jt.cat((past_key, key), dim=-2) |
| 159 | value = jt.cat((past_value, value), dim=-2) |
| 160 | |
| 161 | if use_cache is True: |
No direct callers of `execute` were found.
No test coverage was detected for `execute`.