MCPcopy
hub / github.com/verl-project/verl / default_tp_concat_fn

Function default_tp_concat_fn

verl/utils/megatron_utils.py:969–1048  ·  view source on GitHub ↗

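name: name of the parameter; train_params: training parameters; infer_params (Iterable[torch.Tensor]): an iterator over the list of parameters all-gathered from micro_dp_group; model_config: HuggingFace model_config. TODO(zhangchi.usc1992): currently, the implementation is ad hoc. We can move this function to the model definition so that it is model-agnostic. If the model doesn't implement this function, we can throw an error to force the user to disable TP HybridEngine.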

(
    layer_name_mapping,
    name,
    train_params,
    infer_params,
    model_config,
    hf_config=None,
    convert_qkv_gate_up_by_simple_split=False,
)

Source from the content-addressed store, hash-verified

967
968
969def default_tp_concat_fn(
970 layer_name_mapping,
971 name,
972 train_params,
973 infer_params,
974 model_config,
975 hf_config=None,
976 convert_qkv_gate_up_by_simple_split=False,
977):
978 """
979 name: name of the parameter
980 train_params: training parameters
981 infer_params (Iterable[torch.Tensor]): a iterator towards list of parameters all-gathered from micro_dp_group
982 model_config: huggingface model_config
983 TODO(zhangchi.usc1992): currently, the implementation is adhoc. We can move this function to the model
984 definition so that it is model-agnostic. If the model doesn't implement this function,
985 we can throw an error to force user disable TP HybridEngine.
986 """
987 from megatron.core import mpu
988
989 train_tp_size = mpu.get_tensor_model_parallel_world_size()
990 if layer_name_mapping.get("qkv_layer_name") in name and "layer_norm" not in name:
991 # if the tensor is qkv, for each param on tp, split into q, k, v
992 # concat q, k, v separately.
993 q_lst = []
994 k_lst = []
995 v_lst = []
996 num_attention_heads = model_config.num_attention_heads
997 num_key_value_heads = model_config.num_key_value_heads
998 if "vision_model" in name:
999 num_attention_heads = hf_config.vision_config.num_heads
1000 num_key_value_heads = hf_config.vision_config.num_heads
1001 assert num_attention_heads % num_key_value_heads == 0
1002 num_q_per_kv = num_attention_heads // num_key_value_heads
1003 assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0, (
1004 f"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}"
1005 )
1006 kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2)
1007 split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp]
1008 for infer_param in infer_params:
1009 num_query_groups_per_partition = num_key_value_heads // train_tp_size
1010 for chunk in infer_param.chunk(num_query_groups_per_partition):
1011 split_size = [
1012 kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition,
1013 kv_size_per_tp // num_query_groups_per_partition,
1014 kv_size_per_tp // num_query_groups_per_partition,
1015 ]
1016 q, k, v = chunk.split(split_size)
1017 q_lst.append(q)
1018 k_lst.append(k)
1019 v_lst.append(v)
1020 q = torch.cat(q_lst, dim=0)
1021 k = torch.cat(k_lst, dim=0)
1022 v = torch.cat(v_lst, dim=0)
1023 infer_params = torch.cat((q, k, v), dim=0) if not convert_qkv_gate_up_by_simple_split else [q, k, v]
1024
1025 elif (
1026 layer_name_mapping.get("gate_proj_layer_name") in name

Callers 1

per_tensor_generator · Function · 0.85

Calls 4

append · Method · 0.80
get · Method · 0.45
chunk · Method · 0.45
split · Method · 0.45

Tested by

no test coverage detected