MCPcopy
hub / github.com/InternLM/InternLM / load_scheduler

Function load_scheduler

internlm/utils/model_checkpoint.py:244–267  ·  view source on GitHub ↗
(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train_state: TrainState)

Source from the content-addressed store, hash-verified

242
243
def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train_state: TrainState):
    """Restore LR-scheduler state from a checkpoint, overriding its base learning rate.

    Loads ``schedulder.pt`` from ``ckpt_path``. It then replaces every saved
    ``base_lrs`` entry with ``learning_rate``; the same is done for the nested
    ``after_scheduler_dict["base_lrs"]`` when that key is present. The patched state
    is loaded into ``lr_scheduler`` and its ``last_epoch`` is set to
    ``train_state.step_count + 1``. Finally, each optimizer param group's current
    ``lr`` is multiplied by ``learning_rate / saved_base_lr`` for that group's index.

    Args:
        ckpt_path: Directory that contains the scheduler checkpoint file.
        lr_scheduler: Scheduler object; its ``load_state_dict`` is called and its
            ``last_epoch`` attribute is overwritten.
        optimizer: Optimizer whose ``param_groups[i]["lr"]`` values are rescaled in place.
        learning_rate: New base learning rate to apply.
        train_state: Training state; only ``step_count`` is read.

    Raises:
        KeyError: If the loaded state has no ``"base_lrs"`` entry.
        ZeroDivisionError: If a saved base learning rate is 0.
        IndexError: If the optimizer has more param groups than saved ``base_lrs``.
    """
    # NOTE(review): "schedulder.pt" is misspelled. It presumably matches the name the
    # checkpoint writer uses, so it is kept as-is; confirm before renaming, or
    # existing checkpoints will stop loading.
    scheduler_states = llm_load(os.path.join(ckpt_path, "schedulder.pt"))
    saved_base_lrs = scheduler_states["base_lrs"]
    if learning_rate != saved_base_lrs[0] and gpc.is_rank_for_log():
        logger.warning(
            f"Using new learning rate {learning_rate} to replace old learn rate {saved_base_lrs[0]}."
        )

    # Keep the saved base LRs: they are needed below to compute per-group rescale ratios.
    base_lrs = copy.deepcopy(saved_base_lrs)
    scheduler_states["base_lrs"] = [learning_rate] * len(base_lrs)
    if "after_scheduler_dict" in scheduler_states:
        after_states = scheduler_states["after_scheduler_dict"]
        after_states["base_lrs"] = [learning_rate] * len(after_states["base_lrs"])

    lr_scheduler.load_state_dict(scheduler_states)
    # NOTE(review): the +1 offset is preserved from the original; confirm it matches
    # how step_count is advanced by the training loop.
    lr_scheduler.last_epoch = train_state.step_count + 1

    # Scale each group's current lr by new_base / saved_base for that group index.
    ratios = [learning_rate / lr for lr in base_lrs]
    for idx, param_group in enumerate(optimizer.param_groups):
        param_group["lr"] = param_group["lr"] * ratios[idx]
    # Release cached CUDA memory once, after all groups are rescaled.
    torch.cuda.empty_cache()

    if gpc.is_rank_for_log():
        logger.info(f"reload load_scheduler:{lr_scheduler}")
268
269
270class CheckpointManager:

Callers 1

try_resume_training · Method · 0.85

Calls 3

llm_load · Function · 0.90
is_rank_for_log · Method · 0.80
load_state_dict · Method · 0.45

Tested by

no test coverage detected