MCPcopy
hub / github.com/InternLM/InternLM / load_optimizer_checkpoint

Function load_optimizer_checkpoint

internlm/utils/model_checkpoint.py:171–222  ·  view source on GitHub ↗

Load the optimizer state from the local file system or a remote object storage service (OSS). Args: optim (Optimizer): the optimizer whose state is loaded; folder (str): the FS/OSS path where the optimizer checkpoint is stored.

(folder, optim)

Source from the content-addressed store, hash-verified

169
170
def load_optimizer_checkpoint(folder, optim):
    """Load the optimizer state from the local file system or remote
    object storage Service (OSS).

    Shard files in ``folder`` are named ``optimizer_tp{T}_pp{P}_zo{Z}.pt``.
    The largest indices found among them are compared with the current
    tensor, pipeline and ZeRO-1 group sizes before the shard belonging to
    this rank is loaded and handed to ``optim.load_state_dict``.

    Args:
        folder (str): The FS/OSS path where the optimizer checkpoint is stored.
        optim (Optimizer): The optimizer whose state is restored.

    Raises:
        AssertionError: If the layout implied by the saved shard names does
            not match the current ZeRO-1, pipeline or tensor parallel size.
    """

    max_tp, max_pp, max_zero = 0, 0, 0
    for fn in get_fns(folder):
        # Skip unrelated files and the ".md5" companions of the shards.
        if fn.startswith("optimizer_") and not fn.endswith(".md5"):
            # "optimizer_tp3_pp1_zo7.pt" -> "optimizer", "tp3", "pp1", "zo7";
            # the [2:] slices drop the two-letter "tp"/"pp"/"zo" prefixes.
            _, tp, pp, zero = os.path.splitext(fn)[0].split("_")
            max_zero = max(max_zero, int(zero[2:]))
            max_tp = max(max_tp, int(tp[2:]))
            max_pp = max(max_pp, int(pp[2:]))

    zero_size = gpc.get_world_size(ParallelMode.ZERO1)
    zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
    tp_size = gpc.get_world_size(ParallelMode.TENSOR)
    pp_size = gpc.get_world_size(ParallelMode.PIPELINE)

    # Indices are zero-based, so the saved group size is "max index + 1".
    assert (
        zero_size == max_zero + 1
    ), f"The weights are save for {max_zero+1} data parallel, while current has {zero_size} zero broadcast range."
    assert (
        pp_size == max_pp + 1
    ), f"The weights are save for {max_pp+1} pipelines, while current has {pp_size} pipelines"
    assert (
        tp_size == max_tp + 1
    ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"

    # Build the file name of the shard that belongs to this rank.
    fp = f"optimizer_tp{gpc.get_local_rank(ParallelMode.TENSOR)}_"
    fp += f"pp{gpc.get_local_rank(ParallelMode.PIPELINE)}_"
    fp += f"zo{zero_rank}.pt"
    states = llm_load(os.path.join(folder, fp), map_location=get_current_device())

    if isinstance(optim, HybridZeroOptimizer):
        fp_meta = os.path.join(folder, optim.rank_unique_id)
        try:
            zero_devide_optim_plan = llm_load(fp_meta)
            states.update({"zero_devide_optim_plan": zero_devide_optim_plan})
        except Exception as e:
            # Best effort: a missing or unreadable split plan is reported but
            # does not abort loading of the optimizer state itself.
            logger.warning(
                f"Read zero optimizer split file '{fp_meta}' failed, for '{e}'. "
                "Please check whether loading ckpts are saved with the HybridZeroOptimizer."
            )

    optim.load_state_dict(states)
    # Drop the local reference to the loaded state before asking the CUDA
    # caching allocator to release its unused cached blocks.
    del states
    torch.cuda.empty_cache()
223
224
225def load_sampler(ckpt_path: str, sampler):

Callers 1

try_resume_trainingMethod · 0.85

Calls 7

get_fnsFunction · 0.90
llm_loadFunction · 0.90
get_current_deviceFunction · 0.90
get_world_sizeMethod · 0.80
get_local_rankMethod · 0.80
updateMethod · 0.45
load_state_dictMethod · 0.45

Tested by

no test coverage detected