(
w: np.ndarray,
bits = 2,
g = 4,
)
| 697 | |
| 698 | |
| 699 | def preprocess_weights_tl2( |
| 700 | w: np.ndarray, |
| 701 | bits = 2, |
| 702 | g = 4, |
| 703 | ) -> Tuple[np.ndarray, np.ndarray]: |
| 704 | M, K = w.shape |
| 705 | weight = w |
| 706 | weight = np.where(np.abs(weight) < 1e-6, 0, weight).astype(np.float32) |
| 707 | weight = np.sign(weight) |
| 708 | weight_num = np.prod(weight.shape) |
| 709 | |
| 710 | # for three num 6 bit -> |
| 711 | |
| 712 | # outer loop |
| 713 | KEMD = 1536 |
| 714 | BMEMD = 256 |
| 715 | BYEMD = 96 |
| 716 | |
| 717 | KGQA = 4096 |
| 718 | BMGQA = 128 |
| 719 | BYGQA = 96 |
| 720 | |
| 721 | # inner loop (32row 32num/16index) |
| 722 | bm3 = 32 |
| 723 | by3 = 6 |
| 724 | |
| 725 | if K == KEMD: |
| 726 | BM3 = BMEMD |
| 727 | BY3 = BYEMD |
| 728 | elif K == KGQA: |
| 729 | BM3 = BMGQA |
| 730 | BY3 = BYGQA |
| 731 | else: |
| 732 | raise NotImplementedError |
| 733 | |
| 734 | BM2 = BM3 |
| 735 | BY2 = 32 |
| 736 | # inner loop (32row 32num/16index) |
| 737 | bm2 = 32 |
| 738 | by2 = 4 |
| 739 | |
| 740 | if (weight.shape[1] % BY3 != 0): |
| 741 | slice_k_idx = weight.shape[1] - weight.shape[1] % BY3 |
| 742 | slice_weights = np.split(weight, [slice_k_idx], axis=1) |
| 743 | three_weight = slice_weights[0] |
| 744 | two_weight = slice_weights[1] |
| 745 | else: |
| 746 | three_weight = weight |
| 747 | |
| 748 | final_weight = [] |
| 749 | |
| 750 | preprocess_three_weights_tl2(three_weight.shape[0], |
| 751 | three_weight.shape[1], |
| 752 | three_weight.shape[0] * three_weight.shape[1], |
| 753 | BM3, |
| 754 | BY3, |
| 755 | bm3, |
| 756 | by3, |
no test coverage detected