MCPcopy Index your code
hub / github.com/deepspeedai/DeepSpeed / load_checkpoint

Method load_checkpoint

deepspeed/runtime/engine.py:3733–3823  ·  view source on GitHub ↗

Load training checkpoint. Arguments: load_dir: Required. Directory to load the checkpoint from. tag: Checkpoint tag used as a unique identifier for the checkpoint; if not provided, will attempt to load the tag in the 'latest' file. load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.

(self,
                        load_dir,
                        tag=None,
                        load_module_strict=True,
                        load_optimizer_states=True,
                        load_lr_scheduler_states=True,
                        load_module_only=False,
                        custom_load_fn=None)

Source from the content-addressed store, hash-verified

3731 return ckpt_files
3732
3733 def load_checkpoint(self,
3734 load_dir,
3735 tag=None,
3736 load_module_strict=True,
3737 load_optimizer_states=True,
3738 load_lr_scheduler_states=True,
3739 load_module_only=False,
3740 custom_load_fn=None):
3741 """
3742 Load training checkpoint
3743
3744 Arguments:
3745 load_dir: Required. Directory to load the checkpoint from
3746 tag: Checkpoint tag used as a unique identifier for checkpoint, if not provided will attempt to load tag in 'latest' file
3747 load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
3748 load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance
3749 load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint.
3750 load_module_only: Optional. Boolean to load only the model weights from the checkpoint. Ex. warmstarting.
3751 custom_load_fn: Optional. Custom model load function.
3752
3753 Returns:
3754 A tuple of ``load_path`` and ``client_state``.
3755 *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed.
3756 *``client_state``: State dictionary used for loading required training states in the client code.
3757
3758 Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right
3759 after ``engine.save_checkpoint()``. It is because ``engine.module`` is partitioned, and
3760 ``load_checkpoint()`` wants a pristine model. If insisting to do so, please reinitialize engine
3761 before ``load_checkpoint()``.
3762
3763 """
3764
3765 if tag is None:
3766 latest_tag = "latest_universal" if self.load_universal_checkpoint() else "latest"
3767 latest_path = os.path.join(load_dir, latest_tag)
3768 if os.path.isfile(latest_path):
3769 with open(latest_path, "r") as fd:
3770 tag = fd.read().strip()
3771 else:
3772 if self.load_universal_checkpoint():
3773 raise ValueError(f'Invalid for universal checkpoint: {latest_path} does not exist')
3774 else:
3775 logger.warning(
3776 f"Unable to find latest file at {latest_path}, if trying to load latest "
3777 "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint."
3778 )
3779 return None, None
3780
3781 if self._optimizer_has_ckpt_event_prologue():
3782 # Prepare for checkpoint load by ensuring all parameters are partitioned
3783 self.optimizer.checkpoint_event_prologue()
3784
3785 load_path, client_states = self._load_checkpoint(load_dir,
3786 tag,
3787 load_module_strict=load_module_strict,
3788 load_optimizer_states=load_optimizer_states,
3789 load_lr_scheduler_states=load_lr_scheduler_states,
3790 load_module_only=load_module_only,

Tested by 15

test_missing_latest (Method) · 0.64
_run_test (Method) · 0.64
test_pp_basic (Method) · 0.64
_test (Method) · 0.64
test_gpt2_basic (Method) · 0.64