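Fuse function for fusing weights. (1) fuse_attention_qkv: q => [q1,q2,q3,q4]; k => [k1,k2,k3,k4] or [k1,k2] for GQA; v => [v1,v2,v3,v4] or [v1,v2] for GQA; fused weight => [q1,k1,v1,q2,k2,v2,q3,k3,v3,q4,k4,v4] or for GQA [q1,q2,k1,v1,q3,q4,k2,v2]. (2) fuse_attention_ffn: directly fuse the weights into one part: [gate_weight], [up_weight] => [gate_weight, up_weight].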
(fuse_params, is_qkv=False, num_heads=None, num_key_value_heads=None)
| 618 | |
def fuse_param_func():
    """Return the callable that fuses separate weights into one weight.

    Returns:
        callable: ``fn(fuse_params, is_qkv=False, num_heads=None,
        num_key_value_heads=None)``; see ``fn`` for details.
    """

    def fn(fuse_params, is_qkv=False, num_heads=None, num_key_value_heads=None):
        """Fuse a list of weights along their last axis.

        (1) fuse_attention_qkv (``is_qkv=True``)
            q => [q1,q2,q3,q4]
            k => [k1,k2,k3,k4] or [k1,k2] for GQA
            v => [v1,v2,v3,v4] or [v1,v2] for GQA
            fused weight => [q1,k1,v1,q2,k2,v2,q3,k3,v3,q4,k4,v4]
                    or for GQA [q1,q2,k1,v1,q3,q4,k2,v2]
        (2) fuse_attention_ffn (``is_qkv=False``)
            directly concatenate the weights into one
            [gate_weight], [up_weight] => [gate_weight, up_weight]

        Args:
            fuse_params (list): weights to fuse, either ``numpy.ndarray`` or
                ``paddle.Tensor`` (the type of the first element selects the
                backend). For qkv it must be exactly ``[Q, K, V]``.
            is_qkv (bool, optional): fuse attention qkv weights. Defaults to False.
            num_heads (int, optional): number of query heads; Q is split into
                this many equal parts along the last axis. Defaults to None.
            num_key_value_heads (int, optional): number of key/value heads;
                K and V are each split into this many equal parts along the
                last axis. Defaults to None.

        Returns:
            numpy.ndarray or paddle.Tensor: the fused weight, concatenated
            along the last axis.

        Raises:
            AssertionError: if ``is_qkv`` is set and ``num_heads`` or
                ``num_key_value_heads`` is missing/zero, ``fuse_params`` does
                not hold exactly three weights, or ``num_heads`` is not a
                multiple of ``num_key_value_heads``.
        """
        first = fuse_params[0]
        # Check ndarray first so plain NumPy inputs are dispatched without
        # touching paddle; anything that is not a paddle.Tensor still falls
        # back to the NumPy functions, as before.
        if not isinstance(first, np.ndarray) and isinstance(first, paddle.Tensor):
            # paddle.concat is Paddle's documented concatenation API.
            concat_fn, split_fn = paddle.concat, paddle.split
        else:
            concat_fn, split_fn = np.concatenate, np.split

        if not is_qkv:
            # fuse_attention_ffn
            return concat_fn(fuse_params, axis=-1)

        # fuse_attention_qkv
        assert num_heads, f"num_heads should be number of heads for Q, but got {num_heads}"
        assert (
            num_key_value_heads
        ), f"num_key_value_heads should be number of key_value_heads for K and V, but got {num_key_value_heads}"
        assert (
            len(fuse_params) == 3
        ), f"fuse_params length is not equal 3, it should be Q K V list. but got length {len(fuse_params)}"
        # Without this check the floor division below would silently drop the
        # trailing query heads from the fused weight.
        assert num_heads % num_key_value_heads == 0, (
            f"num_heads should be divisible by num_key_value_heads, "
            f"but got num_heads={num_heads} and num_key_value_heads={num_key_value_heads}"
        )

        # Number of query heads that share one key/value head.
        num_query_groups = num_heads // num_key_value_heads
        q_list = split_fn(fuse_params[0], num_heads, axis=-1)
        k_list = split_fn(fuse_params[1], num_key_value_heads, axis=-1)
        v_list = split_fn(fuse_params[2], num_key_value_heads, axis=-1)

        # Interleave per key/value head: its query heads, then its k, then its v.
        qkv_pairs = []
        for i in range(num_key_value_heads):
            start = i * num_query_groups
            qkv_pairs.extend(q_list[start : start + num_query_groups])
            qkv_pairs.append(k_list[i])
            qkv_pairs.append(v_list[i])
        return concat_fn(qkv_pairs, axis=-1)

    return fn
| 673 |
no test coverage detected