MCPcopy
hub / github.com/InternLM/InternLM / try_load_model

Method try_load_model

internlm/utils/model_checkpoint.py:495–531  ·  view source on GitHub ↗
(self, current_time="")

Source from the content-addressed store, hash-verified

493 return latest_checkpoint
494
def try_load_model(self, current_time=""):
    """Pick the folder to load model weights from and load them, if any.

    ``load_ckpt_folder`` (resume training) takes precedence over
    ``load_model_only_folder`` (weights only). When neither is set, a
    "New Run" banner is logged and no weights are loaded.

    Args:
        current_time: Text embedded verbatim in the log banner.

    Raises:
        ValueError: If both ``load_ckpt_folder`` and
            ``load_model_only_folder`` are set.
    """
    resume_dir = self.load_ckpt_folder
    weights_dir = self.load_model_only_folder

    # The two sources are mutually exclusive; reject the ambiguous configuration.
    if resume_dir and weights_dir:
        raise ValueError(
            "Error, try to use both load_ckpt_folder and load_model_only_folder paths, "
            "if you only need to load model weights (for example starting an SFT task for the first time), "
            "set load_model_only_folder path, if you need to resume training from ckpt, "
            "set load_ckpt_folder or use default value "
            "(if is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)"
        )

    # Only the rank selected by gpc emits the banner.
    emit_banner = gpc.is_rank_for_log()

    if resume_dir:
        weights_source = resume_dir
        if emit_banner:
            hostname = socket.gethostname()
            logger.info(
                f"===========Resume training from `{resume_dir}` {current_time} on host:{hostname}==========="
            )
    elif weights_dir:
        weights_source = weights_dir
        if emit_banner:
            hostname = socket.gethostname()
            logger.info(
                f"===========Load Model from `{weights_dir}` {current_time} on host:{hostname}==========="
            )
    else:
        weights_source = None
        if emit_banner:
            hostname = socket.gethostname()
            global_rank = gpc.get_global_rank()
            tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
            pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
            dp_rank = gpc.get_local_rank(ParallelMode.DATA)
            logger.info(
                f"===========New Run {current_time} on host:{hostname},rank={global_rank},"
                f"tp={tp_rank},pp={pp_rank},dp={dp_rank}==========="
            )

    if weights_source is None:
        return

    # Loading model weights must be done before zero is initialized.
    load_model_checkpoint(folder=weights_source, model=self.model)
532
533 def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl):
534 """Attempt to restore the training state of the last ckpt.

Callers 1

mainFunction · 0.95

Calls 4

load_model_checkpointFunction · 0.85
is_rank_for_logMethod · 0.80
get_global_rankMethod · 0.80
get_local_rankMethod · 0.80

Tested by

no test coverage detected