Load training checkpoint. Arguments: load_dir: Required. Directory to load the checkpoint from. tag: Checkpoint tag used as a unique identifier for the checkpoint; if not provided, will attempt to load the tag in the 'latest' file. load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
(self,
load_dir,
tag=None,
load_module_strict=True,
load_optimizer_states=True,
load_lr_scheduler_states=True,
load_module_only=False,
custom_load_fn=None)
| 3731 | return ckpt_files |
| 3732 | |
| 3733 | def load_checkpoint(self, |
| 3734 | load_dir, |
| 3735 | tag=None, |
| 3736 | load_module_strict=True, |
| 3737 | load_optimizer_states=True, |
| 3738 | load_lr_scheduler_states=True, |
| 3739 | load_module_only=False, |
| 3740 | custom_load_fn=None): |
| 3741 | """ |
| 3742 | Load training checkpoint |
| 3743 | |
| 3744 | Arguments: |
| 3745 | load_dir: Required. Directory to load the checkpoint from |
| 3746 | tag: Checkpoint tag used as a unique identifier for checkpoint, if not provided will attempt to load tag in 'latest' file |
| 3747 | load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match. |
| 3748 | load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance |
| 3749 | load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint. |
| 3750 | load_module_only: Optional. Boolean to load only the model weights from the checkpoint. Ex. warmstarting. |
| 3751 | custom_load_fn: Optional. Custom model load function. |
| 3752 | |
| 3753 | Returns: |
| 3754 | A tuple of ``load_path`` and ``client_state``. |
| 3755 | * ``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed. |
| 3756 | * ``client_state``: State dictionary used for loading required training states in the client code. |
| 3757 | |
| 3758 | Important: under ZeRO3, one cannot load a checkpoint with ``engine.load_checkpoint()`` right |
| 3759 | after ``engine.save_checkpoint()``. This is because ``engine.module`` is partitioned, and |
| 3760 | ``load_checkpoint()`` expects a pristine model. If you need to do so, please reinitialize the engine |
| 3761 | before calling ``load_checkpoint()``. |
| 3762 | |
| 3763 | """ |
| 3764 | |
| 3765 | if tag is None: |
| 3766 | latest_tag = "latest_universal" if self.load_universal_checkpoint() else "latest" |
| 3767 | latest_path = os.path.join(load_dir, latest_tag) |
| 3768 | if os.path.isfile(latest_path): |
| 3769 | with open(latest_path, "r") as fd: |
| 3770 | tag = fd.read().strip() |
| 3771 | else: |
| 3772 | if self.load_universal_checkpoint(): |
| 3773 | raise ValueError(f'Invalid for universal checkpoint: {latest_path} does not exist') |
| 3774 | else: |
| 3775 | logger.warning( |
| 3776 | f"Unable to find latest file at {latest_path}, if trying to load latest " |
| 3777 | "checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint." |
| 3778 | ) |
| 3779 | return None, None |
| 3780 | |
| 3781 | if self._optimizer_has_ckpt_event_prologue(): |
| 3782 | # Prepare for checkpoint load by ensuring all parameters are partitioned |
| 3783 | self.optimizer.checkpoint_event_prologue() |
| 3784 | |
| 3785 | load_path, client_states = self._load_checkpoint(load_dir, |
| 3786 | tag, |
| 3787 | load_module_strict=load_module_strict, |
| 3788 | load_optimizer_states=load_optimizer_states, |
| 3789 | load_lr_scheduler_states=load_lr_scheduler_states, |
| 3790 | load_module_only=load_module_only, |