(self, model: FSDPModule, optim: torch.optim.Optimizer)
| 197 | return {} |
| 198 | |
def save(self, model: FSDPModule, optim: torch.optim.Optimizer):
    """Write the full model and optimizer state dicts to a new checkpoint folder.

    The state dicts are obtained through ``self._get_full_model_state_dict``
    and ``self._get_full_optimizer_state_dict``. Both calls run on every
    rank, before the rank check.
    NOTE(review): the helpers are presumably collective operations -- confirm
    before moving them below the rank guard.

    Only rank 0 touches the filesystem. It creates
    ``{self.folder}/<dcp_api|dtensor_api>/<timestamp>`` and stores the two
    state dicts there with ``torch.save`` under the ``MODEL_CHECKPOINT`` and
    ``OPTIM_CHECKPOINT`` file names.

    Args:
        model: Module whose full state dict is saved.
        optim: Optimizer whose full state dict is saved.

    Returns:
        None.
    """
    full_model_state = self._get_full_model_state_dict(model)
    full_optim_state = self._get_full_optimizer_state_dict(model, optim)

    # Every rank other than 0 is done once the state dicts have been built.
    if torch.distributed.get_rank() != 0:
        return

    # Milliseconds since the epoch; used as the per-checkpoint folder name.
    stamp_ms = int(time.time() * 1000)
    api_dir = 'dcp_api' if self.dcp_api else 'dtensor_api'
    target_dir = f"{self.folder}/{api_dir}/{stamp_ms}"
    model_path = f"{target_dir}/{MODEL_CHECKPOINT}"
    optim_path = f"{target_dir}/{OPTIM_CHECKPOINT}"

    os.makedirs(target_dir, exist_ok=True)
    torch.save(full_model_state, model_path)
    torch.save(full_optim_state, optim_path)
no test coverage detected