(
actor_module,
model_config,
weight_converter,
transformer_config,
layer_name_mapping,
convert_qkv_gate_up_by_simple_split=True,
)
| 1049 | |
| 1050 | |
| 1051 | def per_tensor_generator( |
| 1052 | actor_module, |
| 1053 | model_config, |
| 1054 | weight_converter, |
| 1055 | transformer_config, |
| 1056 | layer_name_mapping, |
| 1057 | convert_qkv_gate_up_by_simple_split=True, |
| 1058 | ): |
| 1059 | from megatron.core import parallel_state as mpu |
| 1060 | |
| 1061 | pp_rank = mpu.get_pipeline_model_parallel_rank() |
| 1062 | ep_size = mpu.get_expert_model_parallel_world_size() |
| 1063 | etp_size = mpu.get_expert_tensor_parallel_world_size() |
| 1064 | ep_group = mpu.get_expert_model_parallel_group() |
| 1065 | etp_group = mpu.get_expert_tensor_parallel_group() |
| 1066 | vpp_size = len(actor_module) |
| 1067 | all_gather_group = mpu.get_tensor_model_parallel_group() |
| 1068 | all_gather_group_size = torch.distributed.get_world_size(group=all_gather_group) |
| 1069 | |
| 1070 | def tensor_generator(): |
| 1071 | for scan_vpp_idx in range(vpp_size): |
| 1072 | existing_keys = set() |
| 1073 | model = unwrap_model(actor_module[scan_vpp_idx]) |
| 1074 | for name, param in model.named_parameters(): |
| 1075 | existing_keys.add(name) |
| 1076 | yield name, param |
| 1077 | # note |
| 1078 | # there is a bug in megatron GPTModel |
| 1079 | # decoder.layers[n].mlp.router.expert_bias" in GPTModel is not registered in named_parameter, but in |
| 1080 | # state_dict(). for now we patch it by adding those keys to extra_keys. |
| 1081 | extra_keys = [x for x in model.state_dict().keys() if "_extra_state" not in x and x not in existing_keys] |
| 1082 | for name in extra_keys: |
| 1083 | yield name, model.state_dict()[name].to(get_device_id()) |
| 1084 | |
| 1085 | # we need first make all rank get full model information |
| 1086 | meta_info = [] |
| 1087 | for scan_vpp_idx in range(vpp_size): |
| 1088 | existing_keys = set() |
| 1089 | model = unwrap_model(actor_module[scan_vpp_idx]) |
| 1090 | for idx, (name, _) in enumerate(model.named_parameters()): |
| 1091 | existing_keys.add(name) |
| 1092 | meta_info.append((pp_rank, scan_vpp_idx, idx, name)) |
| 1093 | extra_keys = [x for x in model.state_dict().keys() if "_extra_state" not in x and x not in existing_keys] |
| 1094 | for name in extra_keys: |
| 1095 | meta_info.append((pp_rank, scan_vpp_idx, idx, name)) |
| 1096 | |
| 1097 | obj_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size() |
| 1098 | torch.distributed.all_gather_object( |
| 1099 | object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group() |
| 1100 | ) |
| 1101 | layer_list_meta = [item for sublist in obj_spec_output for item in sublist] |
| 1102 | |
| 1103 | gen_func = tensor_generator() |
| 1104 | |
| 1105 | # lazy load tensor for full model |
| 1106 | for cur_pp_rank, scan_vpp_idx, idx, name in layer_list_meta: |
| 1107 | if model_config.tie_word_embeddings and ("output_layers" in name): |
| 1108 | import warnings |
nothing calls this directly
no test coverage detected