MCPcopy Index your code
hub / github.com/microsoft/BitNet / preprocess_two_weights_tl2

Function preprocess_two_weights_tl2

utils/generate-dummy-bitnet-model.py:577–618  ·  view source on GitHub ↗
(M, K, weight_num, BM, BY, bm, by, weight, final_weight)

Source from the content-addressed store, hash-verified

575
576
def _swap_middle_quarters(column, bm):
    """Flatten *column* to length ``bm`` with its 2nd and 3rd row-quarters swapped."""
    first, second, third, fourth = np.split(column, 4, axis=0)
    return np.reshape(np.concatenate([first, third, second, fourth], axis=0), bm)


def preprocess_two_weights_tl2(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
    """Pack pairs of weights into bytes, tile by tile, appending to *final_weight*.

    Every two consecutive weights ``(a, b)`` are folded into one code
    ``3 * a + b + 4`` and the codes are laid out as an ``(M, K // 2)`` uint8
    matrix.  That matrix is walked in ``BM``-row blocks, then in ``K // BY``
    column blocks, then in ``bm``-row tiles, then in ``BY // by`` column
    tiles.  Each innermost tile holds two code columns of ``bm`` rows; both
    columns get their 2nd and 3rd row-quarters swapped, and the first column
    is shifted into the high nibble while the second is added as the low
    nibble.  One uint8 array of length ``bm`` is appended per tile.

    Args:
        M: Number of rows of the weight matrix.
        K: Number of columns of the weight matrix.
        weight_num: Total number of weights; the reshapes below only succeed
            when this matches ``weight`` and ``M * K``.
        BM: Row count of an outer block (must divide ``M``).
        BY: Column-block size in weights (``K // BY`` blocks are taken).
        bm: Row count of an inner tile (must divide ``BM`` and be a multiple
            of 4 for the quarter swap).
        by: Column-tile size in weights; only ``4`` is supported.
        weight: Array-like of weights.  NOTE(review): presumably ternary
            values in {-1, 0, 1}, so that codes fall in 0..8 and fit a
            nibble - confirm against the caller.
        final_weight: List that receives the packed tiles (mutated in place).

    Returns:
        None.  Results are delivered through *final_weight*.

    Raises:
        ValueError: If ``by`` is not 4, or (from NumPy) if any of the
            reshapes/splits does not divide evenly.
    """
    # With any other ``by`` the per-tile reshape to ``bm`` elements cannot
    # succeed; fail early with a readable message instead of a NumPy one.
    if by != 4:
        raise ValueError(
            "preprocess_two_weights_tl2 requires by == 4, got {!r}".format(by))

    # Fold each consecutive pair into a single code, then offset by 4 so the
    # result is non-negative before the uint8 cast.
    pairs = np.reshape(weight, (weight_num // 2, 2))
    codes = pairs[:, 0] * 3 + pairs[:, 1] + 4
    codes = np.reshape(codes, (M, K // 2)).astype(np.uint8)

    # np.split is kept for the tiling so that a non-divisible size still
    # raises instead of silently producing ragged tiles.
    for row_block in np.split(codes, M // BM, axis=0):
        for col_block in np.split(row_block, K // BY, axis=1):
            for row_tile in np.split(col_block, BM // bm, axis=0):
                for tile in np.split(row_tile, BY // by, axis=1):
                    # Each tile is two code columns: left -> high nibble,
                    # right -> low nibble.
                    left, right = np.split(tile, 2, axis=1)
                    hi_nibble = _swap_middle_quarters(left, bm) << 4
                    lo_nibble = _swap_middle_quarters(right, bm)
                    final_weight.append(hi_nibble + lo_nibble)
619
620def preprocess_three_weights_tl2(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
621 weight = np.reshape(weight, (weight_num // 3, 3))

Callers 1

preprocess_weights_tl2 — Function · 0.70

Calls 1

astype — Method · 0.45

Tested by

no test coverage detected