Save training checkpoint. Arguments: save_dir: Required. Directory for saving the checkpoint. tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint; the global step is used if not provided. The tag name must be the same across all ranks.
(self, save_dir, tag=None, client_state={}, save_latest=True, exclude_frozen_parameters=False)
| 4086 | logger.warning(msg) |
| 4087 | |
| 4088 | def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, exclude_frozen_parameters=False): |
| 4089 | """Save training checkpoint |
| 4090 | |
| 4091 | Arguments: |
| 4092 | save_dir: Required. Directory for saving the checkpoint |
| 4093 | tag: Optional. Checkpoint tag used as a unique identifier for the checkpoint, global step is |
| 4094 | used if not provided. Tag name must be the same across all ranks. |
| 4095 | client_state: Optional. State dictionary used for saving required training states in the client code. |
| 4096 | save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint. |
| 4097 | exclude_frozen_parameters: Optional. Exclude frozen parameters from checkpointed state. |
| 4098 | Important: all processes must call this method and not just the process with rank 0. It is |
| 4099 | because each process needs to save its master weights and scheduler+optimizer states. This |
| 4100 | method will hang waiting to synchronize with other processes if it's called just for the |
| 4101 | process with rank 0. |
| 4102 | |
| 4103 | """ |
| 4104 | if self._optimizer_has_ckpt_event_prologue(): |
| 4105 | # Custom preparation for checkpoint save, if applicable |
| 4106 | self.optimizer.checkpoint_event_prologue() |
| 4107 | |
| 4108 | rank = self.local_rank if self.use_node_local_storage() else self.global_rank |
| 4109 | |
| 4110 | # This is to make sure the checkpoint names are created without collision |
| 4111 | # There seems to be issue creating them in parallel |
| 4112 | |
| 4113 | # Ensure save_dir directory exists |
| 4114 | if rank == 0: |
| 4115 | self.checkpoint_engine.makedirs(save_dir, exist_ok=True) |
| 4116 | dist.barrier() |
| 4117 | |
| 4118 | if tag is None: |
| 4119 | tag = f"global_step{self.global_steps}" |
| 4120 | |
| 4121 | # Ensure tag is a string |
| 4122 | tag = str(tag) |
| 4123 | commit_info = CheckpointCommitInfo(tag=tag, save_dir=save_dir, save_latest=save_latest) |
| 4124 | |
| 4125 | self.checkpoint_engine.create(commit_info) |
| 4126 | |
| 4127 | # Ensure checkpoint tag is consistent across ranks |
| 4128 | self._checkpoint_tag_validation(tag) |
| 4129 | |
| 4130 | if self.has_moe_layers: |
| 4131 | self.save_non_zero_checkpoint = False |
| 4132 | self._create_checkpoint_file(save_dir, tag, False) |
| 4133 | self._save_moe_checkpoint(save_dir, |
| 4134 | tag, |
| 4135 | client_state=client_state, |
| 4136 | exclude_frozen_parameters=exclude_frozen_parameters) |
| 4137 | |
| 4138 | # We distribute the task of saving layer checkpoint files among |
| 4139 | # data parallel instances, so all procs should call _save_checkpoint. |
| 4140 | # All procs then call module_state_dict(), but only procs of data |
| 4141 | # parallel rank 0 save the general model params. |
| 4142 | if not self.has_moe_layers: |
| 4143 | self._create_checkpoint_file(save_dir, tag, False) |
| 4144 | self._save_checkpoint(save_dir, |
| 4145 | tag, |