Repository: github.com/microsoft/BitNet

Function preprocess_weights_tl2

utils/generate-dummy-bitnet-model.py:699–773
Signature: preprocess_weights_tl2(w: np.ndarray, bits=2, g=4)
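The function body opens by collapsing the input to ternary values. Here is a minimal standalone sketch of just that step, assuming a floating-point weight matrix. `ternarize` and its `eps` parameter are hypothetical names and are not part of BitNet.

```python
import numpy as np


def ternarize(w: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Hypothetical helper mirroring the opening of preprocess_weights_tl2:
    entries with |w| < eps become 0, every other entry becomes its sign."""
    w = np.where(np.abs(w) < eps, 0, w).astype(np.float32)
    return np.sign(w)  # float32 array containing only -1.0, 0.0 and +1.0


if __name__ == "__main__":
    w = np.array([[0.7, -2.0, 1e-8],
                  [0.0, -1e-7, 3.5]])
    print(ternarize(w))
```

Magnitudes are discarded at this point, so only the sign pattern of `w` reaches the block-size and packing logic.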


```python
from typing import Tuple

import numpy as np


def preprocess_weights_tl2(
    w: np.ndarray,
    bits=2,
    g=4,
) -> Tuple[np.ndarray, np.ndarray]:
    # bits, g, M and weight_num are not used in the part of the body shown here.
    M, K = w.shape
    weight = w
    # Treat near-zero entries as exact zeros, then keep only the sign (-1, 0, +1).
    weight = np.where(np.abs(weight) < 1e-6, 0, weight).astype(np.float32)
    weight = np.sign(weight)
    weight_num = np.prod(weight.shape)

    # for three num 6 bit ->

    # outer loop
    KEMD = 1536
    BMEMD = 256
    BYEMD = 96

    KGQA = 4096
    BMGQA = 128
    BYGQA = 96

    # inner loop (32row 32num/16index)
    bm3 = 32
    by3 = 6

    # Outer block sizes are hard-coded per K; any other K is rejected.
    if K == KEMD:
        BM3 = BMEMD
        BY3 = BYEMD
    elif K == KGQA:
        BM3 = BMGQA
        BY3 = BYGQA
    else:
        raise NotImplementedError

    BM2 = BM3
    BY2 = 32
    # inner loop (32row 32num/16index)
    bm2 = 32
    by2 = 4

    # When K is not a multiple of BY3, the leading columns (a multiple of BY3)
    # go to three_weight and the trailing K % BY3 columns go to two_weight.
    # two_weight is only defined on this branch.
    if weight.shape[1] % BY3 != 0:
        slice_k_idx = weight.shape[1] - weight.shape[1] % BY3
        slice_weights = np.split(weight, [slice_k_idx], axis=1)
        three_weight = slice_weights[0]
        two_weight = slice_weights[1]
    else:
        three_weight = weight

    final_weight = []

    preprocess_three_weights_tl2(three_weight.shape[0],
                                 three_weight.shape[1],
                                 three_weight.shape[0] * three_weight.shape[1],
                                 BM3,
                                 BY3,
                                 bm3,
                                 by3,
                                 # ... the excerpt is truncated here; source
                                 # lines 757 to 773 are not shown.
```
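The block-size dispatch and column split in the excerpt can be reproduced in isolation. This sketch copies the K-specific constants (BM3, BY3) verbatim and stops before the packing calls, which the excerpt does not show. `TL2_BLOCKS` and `split_for_tl2` are hypothetical names.

```python
from typing import Optional, Tuple

import numpy as np

# K -> (BM3, BY3), copied from the excerpt; any other K is unsupported.
TL2_BLOCKS = {1536: (256, 96), 4096: (128, 96)}


def split_for_tl2(
    weight: np.ndarray,
) -> Tuple[np.ndarray, Optional[np.ndarray], int, int]:
    """Hypothetical helper: pick (BM3, BY3) from K, then split the matrix into
    the leading BY3-aligned columns and the leftover K % BY3 columns."""
    _, K = weight.shape
    if K not in TL2_BLOCKS:
        raise NotImplementedError(f"no TL2 block sizes for K={K}")
    BM3, BY3 = TL2_BLOCKS[K]
    rem = K % BY3
    if rem == 0:
        # No leftover columns: everything goes to the three-weight path.
        return weight, None, BM3, BY3
    three_weight, two_weight = np.split(weight, [K - rem], axis=1)
    return three_weight, two_weight, BM3, BY3


if __name__ == "__main__":
    for K in (1536, 4096):
        three, two, BM3, BY3 = split_for_tl2(np.zeros((4, K), dtype=np.float32))
        print(K, BM3, BY3, three.shape, None if two is None else two.shape)
```

With these constants, K = 1536 splits evenly (16 × 96 columns). K = 4096 leaves 64 trailing columns after 42 × 96 = 4032, so only the 4096 case reaches the two_weight path. Returning `None` for the missing part is a choice made here; the original simply leaves `two_weight` undefined.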

Callers (1): transform_to_tl2 (method)

Calls (3; only one is listed here): astype (method)

Tested by: no test coverage detected
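Because no tests are detected, a cheap first property to pin down is the column arithmetic implied by the constants. This sketch assumes, without confirmation from the excerpt, that the two-weight path needs its K % BY3 leftover columns to form a whole number of BY2 = 32 blocks. `tl2_column_blocks` is a hypothetical helper.

```python
from typing import Tuple

# Column block widths copied from the excerpt of preprocess_weights_tl2.
BY3 = 96  # outer column block for the three-weight path (both supported K)
BY2 = 32  # outer column block for the two-weight path
SUPPORTED_K = (1536, 4096)


def tl2_column_blocks(K: int) -> Tuple[int, int]:
    """Hypothetical helper: count BY3 blocks in the leading columns and BY2
    blocks in the remainder. Raises if the remainder is not BY2-aligned,
    which is an assumed requirement, not one stated in the excerpt."""
    rem = K % BY3
    if rem % BY2 != 0:
        raise ValueError(f"remainder {rem} is not a multiple of BY2={BY2}")
    return K // BY3, rem // BY2


for K in SUPPORTED_K:
    print(K, tl2_column_blocks(K))
```

Both supported shapes pass: 1536 gives (16, 0) and 4096 gives (42, 2).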