MCPcopy
hub / github.com/deepspeedai/DeepSpeed / save_checkpoint

Method save_checkpoint

deepspeed/runtime/engine.py:4088–4182  ·  view source on GitHub ↗

Save training checkpoint. Arguments: save_dir: Required. Directory for saving the checkpoint. tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint; the global step is used if not provided. The tag name must be the same across all ranks.

(self, save_dir, tag=None, client_state={}, save_latest=True, exclude_frozen_parameters=False)

Source from the content-addressed store, hash-verified

4086 logger.warning(msg)
4087
4088 def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, exclude_frozen_parameters=False):
4089 """Save training checkpoint
4090
4091 Arguments:
4092 save_dir: Required. Directory for saving the checkpoint
4093 tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is
4094 used if not provided. Tag name must be the same across all ranks.
4095 client_state: Optional. State dictionary used for saving required training states in the client code.
4096 save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint.
4097 exclude_frozen_parameters: Optional. Exclude frozen parameters from checkpointed state.
4098 Important: all processes must call this method and not just the process with rank 0. It is
4099 because each process needs to save its master weights and scheduler+optimizer states. This
4100 method will hang waiting to synchronize with other processes if it's called just for the
4101 process with rank 0.
4102
4103 """
4104 if self._optimizer_has_ckpt_event_prologue():
4105 # Custom preparation for checkpoint save, if applicable
4106 self.optimizer.checkpoint_event_prologue()
4107
4108 rank = self.local_rank if self.use_node_local_storage() else self.global_rank
4109
4110 # This is to make sure the checkpoint names are created without collision
4111 # There seems to be issue creating them in parallel
4112
4113 # Ensure save_dir directory exists
4114 if rank == 0:
4115 self.checkpoint_engine.makedirs(save_dir, exist_ok=True)
4116 dist.barrier()
4117
4118 if tag is None:
4119 tag = f"global_step{self.global_steps}"
4120
4121 # Ensure tag is a string
4122 tag = str(tag)
4123 commit_info = CheckpointCommitInfo(tag=tag, save_dir=save_dir, save_latest=save_latest)
4124
4125 self.checkpoint_engine.create(commit_info)
4126
4127 # Ensure checkpoint tag is consistent across ranks
4128 self._checkpoint_tag_validation(tag)
4129
4130 if self.has_moe_layers:
4131 self.save_non_zero_checkpoint = False
4132 self._create_checkpoint_file(save_dir, tag, False)
4133 self._save_moe_checkpoint(save_dir,
4134 tag,
4135 client_state=client_state,
4136 exclude_frozen_parameters=exclude_frozen_parameters)
4137
4138 # We distribute the task of saving layer checkpoint files among
4139 # data parallel instances, so all procs should call _save_checkpoint.
4140 # All procs then call module_state_dict(), but only procs of data
4141 # parallel rank 0 save the general model params.
4142 if not self.has_moe_layers:
4143 self._create_checkpoint_file(save_dir, tag, False)
4144 self._save_checkpoint(save_dir,
4145 tag,