Save 16bit model weights. This method saves the 16bit model weights at the desired destination. Arguments: save_dir: Required. Directory for saving the model. save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin``.
(self, save_dir, save_filename="pytorch_model.bin", exclude_frozen_parameters=False)
| 4748 | return self.save_16bit_model(save_dir, save_filename) |
| 4749 | |
def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin", exclude_frozen_parameters=False):
    """
    Write the 16bit model weights to ``save_dir/save_filename``.

    Arguments:
        save_dir: Required. Directory for saving the model
        save_filename: Optional. Filename to save to. Defaults to ``pytorch_model.bin``
        exclude_frozen_parameters: Optional. Exclude frozen parameters from checkpointed state.

    Returns:
        ``True`` when the save path was taken, ``False`` otherwise. ``False`` is returned,
        and nothing is written, when ``zero_optimization_partition_weights()`` is true but
        ``zero_gather_16bit_weights_on_model_save()`` is false, i.e. when
        stage3_gather_16bit_weights_on_model_save is ``False``. ``True`` is returned on every
        rank, although only rank 0 writes the file.

    Important: all processes must call this method and not just the process with rank 0. It is
    because the processes need to work in sync to gather the weights. This method will hang
    waiting to synchronize with other processes if it's called just for the process with rank 0.

    """
    path = os.path.join(save_dir, save_filename)

    weights_partitioned = self.zero_optimization_partition_weights()

    # Guard: partitioned weights without consolidation would produce a bogus model,
    # so don't confuse the user by saving it.
    if weights_partitioned and not self.zero_gather_16bit_weights_on_model_save():
        logger.info(f"Did not save the model {path} because stage3_gather_16bit_weights_on_model_save is False")
        return False

    if weights_partitioned:
        # consolidation is expensive in time and memory and therefore isn't a default
        state_dict = self._zero3_consolidated_16bit_state_dict(exclude_frozen_parameters=exclude_frozen_parameters)
    else:
        state_dict = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters)

    # An f-string already yields a ``str``, so no further conversion is needed.
    tag = f"global_step{self.global_steps}"
    commit_info = CheckpointCommitInfo(tag=tag, save_dir=save_dir, save_latest=False)

    # create/commit run on every rank; only rank 0 makes the directory and writes the file.
    self.checkpoint_engine.create(commit_info)

    is_writer_rank = dist.get_rank() == 0
    if is_writer_rank:
        self.checkpoint_engine.makedirs(save_dir, exist_ok=True)
        logger.info(f"Saving model weights to {path}, tag: {tag}")
        self.checkpoint_engine.save(state_dict, path)

    self.checkpoint_engine.commit(commit_info)

    return True
| 4799 | |
| 4800 | def empty_partition_cache(self): |
| 4801 | """ |