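name: name of the parameter; train_params: training parameters; infer_params (Iterable[torch.Tensor]): an iterator over the list of parameters all-gathered from micro_dp_group; model_config: Hugging Face model_config. TODO(zhangchi.usc1992): currently, the implementation is ad hoc. We can move this function to the model definition so that it is model-agnostic. If the model doesn't implement this function, we can throw an error to force the user to disable TP HybridEngine.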
(
layer_name_mapping,
name,
train_params,
infer_params,
model_config,
hf_config=None,
convert_qkv_gate_up_by_simple_split=False,
)
| 967 | |
| 968 | |
| 969 | def default_tp_concat_fn( |
| 970 | layer_name_mapping, |
| 971 | name, |
| 972 | train_params, |
| 973 | infer_params, |
| 974 | model_config, |
| 975 | hf_config=None, |
| 976 | convert_qkv_gate_up_by_simple_split=False, |
| 977 | ): |
| 978 | """ |
| 979 | name: name of the parameter |
| 980 | train_params: training parameters |
| 981 | infer_params (Iterable[torch.Tensor]): a iterator towards list of parameters all-gathered from micro_dp_group |
| 982 | model_config: huggingface model_config |
| 983 | TODO(zhangchi.usc1992): currently, the implementation is adhoc. We can move this function to the model |
| 984 | definition so that it is model-agnostic. If the model doesn't implement this function, |
| 985 | we can throw an error to force user disable TP HybridEngine. |
| 986 | """ |
| 987 | from megatron.core import mpu |
| 988 | |
| 989 | train_tp_size = mpu.get_tensor_model_parallel_world_size() |
| 990 | if layer_name_mapping.get("qkv_layer_name") in name and "layer_norm" not in name: |
| 991 | # if the tensor is qkv, for each param on tp, split into q, k, v |
| 992 | # concat q, k, v separately. |
| 993 | q_lst = [] |
| 994 | k_lst = [] |
| 995 | v_lst = [] |
| 996 | num_attention_heads = model_config.num_attention_heads |
| 997 | num_key_value_heads = model_config.num_key_value_heads |
| 998 | if "vision_model" in name: |
| 999 | num_attention_heads = hf_config.vision_config.num_heads |
| 1000 | num_key_value_heads = hf_config.vision_config.num_heads |
| 1001 | assert num_attention_heads % num_key_value_heads == 0 |
| 1002 | num_q_per_kv = num_attention_heads // num_key_value_heads |
| 1003 | assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0, ( |
| 1004 | f"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}" |
| 1005 | ) |
| 1006 | kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2) |
| 1007 | split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] |
| 1008 | for infer_param in infer_params: |
| 1009 | num_query_groups_per_partition = num_key_value_heads // train_tp_size |
| 1010 | for chunk in infer_param.chunk(num_query_groups_per_partition): |
| 1011 | split_size = [ |
| 1012 | kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, |
| 1013 | kv_size_per_tp // num_query_groups_per_partition, |
| 1014 | kv_size_per_tp // num_query_groups_per_partition, |
| 1015 | ] |
| 1016 | q, k, v = chunk.split(split_size) |
| 1017 | q_lst.append(q) |
| 1018 | k_lst.append(k) |
| 1019 | v_lst.append(v) |
| 1020 | q = torch.cat(q_lst, dim=0) |
| 1021 | k = torch.cat(k_lst, dim=0) |
| 1022 | v = torch.cat(v_lst, dim=0) |
| 1023 | infer_params = torch.cat((q, k, v), dim=0) if not convert_qkv_gate_up_by_simple_split else [q, k, v] |
| 1024 | |
| 1025 | elif ( |
| 1026 | layer_name_mapping.get("gate_proj_layer_name") in name |
no test coverage detected