(self, x, n_head, dim_head, mp_num)
| 42 | self.rotary_dim = config.rotary_dim |
| 43 | |
def _split_heads(self, x, n_head, dim_head, mp_num):
    """Reshape ``x`` into per-head slices, folding its second-to-last axis in.

    Two reshapes are applied:

    1. The last axis of ``x`` is split into ``(n_head // mp_num, dim_head)``.
       This requires ``x.shape[-1] == (n_head // mp_num) * dim_head``.
    2. The second-to-last axis of ``x`` and the new per-shard head axis are
       collapsed into a single axis. The result therefore has shape
       ``x.shape[:-2] + (x.shape[-2] * (n_head // mp_num), dim_head)``.

    NOTE(review): callers presumably pass ``x`` with ``x.shape[-2] == mp_num``,
    so the collapsed axis equals ``n_head``. This is not checked here; confirm
    against the caller.

    Args:
        x: Array/tensor exposing ``.shape`` and ``.reshape``.
        n_head: Total number of attention heads.
        dim_head: Size of each head.
        mp_num: Number of shards the heads are divided across.

    Returns:
        The reshaped array/tensor (same elements, same order).
    """
    heads_per_shard = n_head // mp_num
    split = x.reshape(x.shape[:-1] + (heads_per_shard, dim_head))
    # Merge x's second-to-last axis with the per-shard head axis; -1 lets the
    # backend infer x.shape[-2] * heads_per_shard.
    merged_shape = x.shape[:-2] + (-1,) + split.shape[-1:]
    return split.reshape(merged_shape)
| 48 | |
| 49 | def _merge_heads(self, tensor, num_attention_heads, attn_head_size): |
| 50 | """ |