MCPcopy
hub / github.com/InternLM/InternLM / load_model_checkpoint

Function load_model_checkpoint

internlm/utils/model_checkpoint.py:99–140  ·  view source on GitHub ↗

The folder should contain weight files named as follows: - folder - model_tp{tp_rank}_pp{pp_rank}.pt If the tensor-parallel (tp) size used when loading differs from the one the weights were saved with, the weights need to be converted before loading.

(folder, model)

Source from the content-addressed store, hash-verified

97
98
def load_model_checkpoint(folder, model):
    """
    Load the weight shard belonging to the current rank from ``folder`` into ``model``.

    The folder is expected to contain one weight file per (tensor, pipeline) rank pair:
    - folder
        - model_tp{tp_rank}_pp{pp_rank}.pt

    The tensor- and pipeline-parallel sizes of the running job must equal the ones
    inferred from the file names in ``folder``; if they differ, the weights have to
    be converted before they can be loaded.

    Args:
        folder: location that holds the ``model_tp*_pp*.pt`` shard files.
        model: object exposing ``load_state_dict``; receives the loaded weights.

    Raises:
        AssertionError: if the current tensor/pipeline parallel size does not match
            the size inferred from the shard file names.
    """
    tp_size = gpc.get_world_size(ParallelMode.TENSOR)
    pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
    tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
    pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    # Infer the saved layout from the highest rank indices found in the file names
    # (e.g. "model_tp3_pp1.pt" -> tp index 3, pp index 1). ".md5" files are skipped.
    highest_pp = 0
    highest_tp = 0
    for filename in get_fns(folder):
        if not filename.startswith("model_t") or filename.endswith(".md5"):
            continue
        parts = os.path.splitext(filename)[0].split("_")
        highest_pp = max(highest_pp, int(parts[-1][2:]))
        highest_tp = max(highest_tp, int(parts[-2][2:]))

    # Rank indices are zero-based, so the saved sizes are the highest index plus one.
    saved_pp_size = highest_pp + 1
    saved_tp_size = highest_tp + 1
    assert (
        pp_size == saved_pp_size
    ), f"The weights are save for {saved_pp_size} pipelines, while current has {pp_size} pipelines"
    assert (
        tp_size == saved_tp_size
    ), f"The weights are save for {saved_tp_size} parallelism, while current has {tp_size} tensor parallelism"

    shard_path = os.path.join(folder, f"model_tp{tp_rank}_pp{pp_rank}.pt")
    states = llm_load(shard_path, map_location=get_current_device())

    # strict=False: key mismatches are reported as warnings instead of raising.
    missing, unexpected = model.load_state_dict(states, strict=False)
    if len(missing) > 0:
        logger.warning(f"Warning: missing keys {missing}")
    if len(unexpected) > 0:
        logger.warning(f"Warning: unexpected keys {unexpected}")

    # Drop the loaded state dict and clear the CUDA cache to avoid OOM,
    # Ref: https://discuss.pytorch.org/t/load-state-dict-causes-memory-leak/36189/11
    del states
    torch.cuda.empty_cache()
141
142
143def save_optimizer_checkpoint(optim, state_path):

Callers 1

try_load_modelMethod · 0.85

Calls 6

get_fnsFunction · 0.90
llm_loadFunction · 0.90
get_current_deviceFunction · 0.90
get_world_sizeMethod · 0.80
get_local_rankMethod · 0.80
load_state_dictMethod · 0.45

Tested by

no test coverage detected