MCPcopy
hub / github.com/verl-project/verl / per_tensor_generator

Function per_tensor_generator

verl/utils/megatron_utils.py:1051–1193  ·  view source on GitHub ↗
(
    actor_module,
    model_config,
    weight_converter,
    transformer_config,
    layer_name_mapping,
    convert_qkv_gate_up_by_simple_split=True,
)

Source from the content-addressed store, hash-verified

1049
1050
1051def per_tensor_generator(
1052 actor_module,
1053 model_config,
1054 weight_converter,
1055 transformer_config,
1056 layer_name_mapping,
1057 convert_qkv_gate_up_by_simple_split=True,
1058):
1059 from megatron.core import parallel_state as mpu
1060
1061 pp_rank = mpu.get_pipeline_model_parallel_rank()
1062 ep_size = mpu.get_expert_model_parallel_world_size()
1063 etp_size = mpu.get_expert_tensor_parallel_world_size()
1064 ep_group = mpu.get_expert_model_parallel_group()
1065 etp_group = mpu.get_expert_tensor_parallel_group()
1066 vpp_size = len(actor_module)
1067 all_gather_group = mpu.get_tensor_model_parallel_group()
1068 all_gather_group_size = torch.distributed.get_world_size(group=all_gather_group)
1069
1070 def tensor_generator():
1071 for scan_vpp_idx in range(vpp_size):
1072 existing_keys = set()
1073 model = unwrap_model(actor_module[scan_vpp_idx])
1074 for name, param in model.named_parameters():
1075 existing_keys.add(name)
1076 yield name, param
1077 # NOTE:
1078 # There is a bug in Megatron's GPTModel:
1079 # "decoder.layers[n].mlp.router.expert_bias" in GPTModel is not registered in named_parameters(), but it is in
1080 # state_dict(). For now we patch it by adding those keys to extra_keys.
1081 extra_keys = [x for x in model.state_dict().keys() if "_extra_state" not in x and x not in existing_keys]
1082 for name in extra_keys:
1083 yield name, model.state_dict()[name].to(get_device_id())
1084
1085 # First, make every pipeline-parallel rank aware of the full model's parameter metadata.
1086 meta_info = []
1087 for scan_vpp_idx in range(vpp_size):
1088 existing_keys = set()
1089 model = unwrap_model(actor_module[scan_vpp_idx])
1090 for idx, (name, _) in enumerate(model.named_parameters()):
1091 existing_keys.add(name)
1092 meta_info.append((pp_rank, scan_vpp_idx, idx, name))
1093 extra_keys = [x for x in model.state_dict().keys() if "_extra_state" not in x and x not in existing_keys]
1094 for name in extra_keys:
1095 meta_info.append((pp_rank, scan_vpp_idx, idx, name))
1096
1097 obj_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size()
1098 torch.distributed.all_gather_object(
1099 object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group()
1100 )
1101 layer_list_meta = [item for sublist in obj_spec_output for item in sublist]
1102
1103 gen_func = tensor_generator()
1104
1105 # Lazily load tensors for the full model.
1106 for cur_pp_rank, scan_vpp_idx, idx, name in layer_list_meta:
1107 if model_config.tie_word_embeddings and ("output_layers" in name):
1108 import warnings

Callers

nothing calls this directly

Calls 12

normalize_model_name · Function · 0.90
unwrap_model · Function · 0.85
tensor_generator · Function · 0.85
default_tp_concat_fn · Function · 0.85
append · Method · 0.80
state_dict · Method · 0.80
add · Method · 0.45
all_gather · Method · 0.45
split · Method · 0.45
convert_param · Method · 0.45

Tested by

no test coverage detected