(M, K, weight_num, BM, BY, bm, by, weight, final_weight)
| 575 | |
| 576 | |
def preprocess_two_weights_tl2(M, K, weight_num, BM, BY, bm, by, weight, final_weight):
    """Pack pairs of weights into byte codes and append them, tile by tile, to ``final_weight``.

    Every two consecutive weights ``(a, b)`` are combined into the code
    ``3 * a + b + 4``; the codes are laid out row-major as an ``(M, K // 2)``
    uint8 matrix.  That matrix is then cut, in this nesting order, into
    ``M // BM`` row bands, ``K // BY`` column tiles, ``BM // bm`` row groups and
    ``BY // by`` column cells.  Each cell is halved column-wise; the left half
    is shifted up by four bits and added to the right half, giving one byte
    per row of the cell.

    Args:
        M: Number of rows of the code matrix.
        K: Number of weights per row (the code matrix has ``K // 2`` columns).
        weight_num: Total number of values in ``weight``.
        BM, BY: Divisors used for the outer row/column split counts.
        bm, by: Divisors used for the inner row/column split counts; ``bm`` is
            also the length each half-cell is flattened to.
        weight: Array-like with ``weight_num`` values.  Presumably ternary
            (-1, 0, 1) so that every code fits in four bits -- TODO confirm
            with the caller.
        final_weight: List that receives one 1-D uint8 array per cell
            (mutated in place).

    Returns:
        None.  Results are delivered through ``final_weight``.

    NOTE(review): flattening a half-cell to ``bm`` values only works when the
    half-cell is a single column, which in practice looks like ``by == 4`` --
    confirm against callers.
    """

    def _reorder_quarters(half):
        # Row quarters q0..q3 are emitted in the order q0, q2, q1, q3 and
        # flattened to bm values.
        q0, q1, q2, q3 = np.split(half, 4, axis=0)
        reordered = np.concatenate([q0, q2, q1, q3], axis=0, dtype=np.uint8)
        return np.reshape(reordered, bm)

    pair_count = weight_num // 2
    pairs = np.reshape(weight, (pair_count, 2))
    first, second = np.split(pairs, 2, axis=1)
    codes = np.reshape(np.multiply(first, 3) + second, pair_count)

    # Shift the codes to be non-negative, then view them row-major as (M, K // 2).
    codes = np.reshape(codes + 4, (M, K // 2)).astype(np.uint8)

    for row_band in np.split(codes, M // BM, axis=0):
        for tile in np.split(row_band, K // BY, axis=1):
            for row_group in np.split(tile, BM // bm, axis=0):
                for cell in np.split(row_group, BY // by, axis=1):
                    left_half, right_half = np.split(cell, 2, axis=1)
                    # Left code goes to the upper four bits, right code to the lower.
                    upper = _reorder_quarters(left_half).astype(np.uint8) << 4
                    lower = _reorder_quarters(right_half)
                    final_weight.append(np.reshape(upper + lower, bm * by // 4))
| 619 | |
| 620 | def preprocess_three_weights_tl2(M, K, weight_num, BM, BY, bm, by, weight, final_weight): |
| 621 | weight = np.reshape(weight, (weight_num // 3, 3)) |
no test coverage detected