hub / github.com/PaddlePaddle/PaddleFormers / fn

Function fn

paddleformers/transformers/conversion_utils.py:620–670

Fuses a list of weights along the last axis. For attention QKV weights (is_qkv=True), the Q, K and V heads are interleaved per key-value group: [q1,k1,v1,q2,k2,v2,q3,k3,v3,q4,k4,v4], or [q1,q2,k1,v1,q3,q4,k2,v2] for GQA. For FFN weights, [gate_weight] and [up_weight] are concatenated directly into [gate_weight, up_weight].

(fuse_params, is_qkv=False, num_heads=None, num_key_value_heads=None)

Source from the content-addressed store, hash-verified

def fuse_param_func():
    def fn(fuse_params, is_qkv=False, num_heads=None, num_key_value_heads=None):
        """fuse function for fusing weights

        (1) fuse_attention_qkv
            q => [q1,q2,q3,q4]
            k => [k1,k2,k3,k4] or [k1,k2] for GQA
            v => [v1,v2,v3,v4] or [v1,v2] for GQA
            fused weight => [q1,k1,v1,q2,k2,v2,q3,k3,v3,q4,k4,v4]
                    or for GQA [q1,q2,k1,v1,q3,q4,k2,v2]
        (2) fuse_attention_ffn
            directly fuse weights to 1 parts
            [gate_weight], [up_weight] => [gate_weight, up_weight]

        Args:
            fuse_params (_type_): to be fused weights
            is_qkv (bool, optional): for attention qkv weights. Defaults to False.
            num_heads (_type_, optional): query heads. Defaults to None.
            num_key_value_heads (_type_, optional): key and value heads. Defaults to None.

        Returns:
            _type_: fused weights
        """
        concat_fn = np.concatenate
        split_fn = np.split
        if isinstance(fuse_params[0], paddle.Tensor):
            concat_fn = paddle.cat
            split_fn = paddle.split

        if is_qkv:
            # fuse_attention_qkv
            assert num_heads, f"num_heads should be number of heads for Q, but got {num_heads}"
            assert (
                num_key_value_heads
            ), f"num_key_value_heads should be number of key_value_heads for K and V, but got {num_key_value_heads}"
            assert (
                len(fuse_params) == 3
            ), f"fuse_params length is not equal 3, it should be Q K V list. but got length {len(fuse_params)}"
            num_query_groups = num_heads // num_key_value_heads
            q_list = split_fn(fuse_params[0], num_heads, axis=-1)
            k_list = split_fn(fuse_params[1], num_key_value_heads, axis=-1)
            v_list = split_fn(fuse_params[2], num_key_value_heads, axis=-1)

            qkv_pairs = []
            for i in range(num_key_value_heads):
                qkv_pairs += q_list[i * num_query_groups : (i + 1) * num_query_groups]
                qkv_pairs.append(k_list[i])
                qkv_pairs.append(v_list[i])
            return concat_fn(qkv_pairs, axis=-1)
        else:
            # fuse_attention_ffn
            return concat_fn(fuse_params, axis=-1)

    return fn

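The head ordering described in the docstring can be checked on toy data. The following is a minimal NumPy-only sketch, not the library's code: `fuse_qkv_sketch` is a hypothetical helper that restates the `is_qkv=True` branch of `fn`, and the weights are made-up single-row arrays with a head dimension of 1, so each column stands for one head.

```python
import numpy as np


def fuse_qkv_sketch(q, k, v, num_heads, num_key_value_heads):
    """Hypothetical NumPy restatement of the is_qkv=True branch of fn."""
    # Number of query heads that share one key/value head.
    group = num_heads // num_key_value_heads
    q_list = np.split(q, num_heads, axis=-1)
    k_list = np.split(k, num_key_value_heads, axis=-1)
    v_list = np.split(v, num_key_value_heads, axis=-1)

    parts = []
    for i in range(num_key_value_heads):
        # All query heads of group i, followed by that group's K and V head.
        parts += q_list[i * group : (i + 1) * group]
        parts.append(k_list[i])
        parts.append(v_list[i])
    return np.concatenate(parts, axis=-1)


# Made-up weights: q heads are 1..4, k heads are 11.., v heads are 21..
q = np.array([[1, 2, 3, 4]])

# GQA: 4 query heads share 2 key/value heads.
gqa = fuse_qkv_sketch(q, np.array([[11, 12]]), np.array([[21, 22]]),
                      num_heads=4, num_key_value_heads=2)
print(gqa.tolist())  # [[1, 2, 11, 21, 3, 4, 12, 22]]

# MHA: every query head has its own key/value head.
mha = fuse_qkv_sketch(q, np.array([[11, 12, 13, 14]]), np.array([[21, 22, 23, 24]]),
                      num_heads=4, num_key_value_heads=4)
print(mha.tolist())  # [[1, 11, 21, 2, 12, 22, 3, 13, 23, 4, 14, 24]]

# FFN case: the non-QKV branch is a plain concatenation on the last axis.
gate, up = np.array([[7, 8]]), np.array([[9, 10]])
ffn = np.concatenate([gate, up], axis=-1)
print(ffn.tolist())  # [[7, 8, 9, 10]]
```

The group size is computed with integer division, so the sketch, like the source above, assumes `num_heads` is a multiple of `num_key_value_heads`; otherwise the trailing query heads would not appear in the fused result.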
Callers 8

wrapper · Function · 0.50
wrapper · Function · 0.50
apply · Method · 0.50
_transform · Method · 0.50
_filter · Method · 0.50
_map · Method · 0.50
_transform · Method · 0.50
_filter · Method · 0.50

Calls 7

naive_fuse_merge_tp · Function · 0.85
normal_fuse_merge_tp · Function · 0.85
naive_fuse_split_tp · Function · 0.85
normal_fuse_split_tp · Function · 0.85
append · Method · 0.45

Tested by

no test coverage detected