(self, iteration, update_best=False)
| 202 | } |
| 203 | |
def save(self, iteration, update_best=False):
    """Write a training checkpoint for ``iteration`` to disk.

    The checkpoint dict contains:

    - the model and optimizer state dicts;
    - the early-stopping bookkeeping (``best_iteration``, ``best_metric_value``);
    - ``self.config``;
    - whatever fields ``self._get_vcs_fields()`` returns.

    It is saved to ``<models_foldername>/model_<iteration>.ckpt``. When
    ``update_best`` is true, the same bytes are also stored at
    ``<ckpt_foldername>/<ckpt_prefix>best.ckpt``.

    Each file is first written to a ``.tmp`` sibling and then moved into
    place with ``os.replace``. A crash or kill mid-write therefore never
    leaves a truncated checkpoint at the final path. In particular, an
    interrupted update cannot destroy a previously valid best checkpoint.

    Only the main process writes; every other process returns without
    doing anything.

    Args:
        iteration: Iteration number, formatted with ``%d`` into the
            checkpoint file name.
        update_best: If true, also refresh the "best" checkpoint with this
            checkpoint's contents.
    """
    import shutil

    # Only save in main process (presumably rank 0 in distributed runs —
    # see is_main_process for the exact rule).
    if not is_main_process():
        return

    ckpt_filepath = os.path.join(
        self.models_foldername, "model_%d.ckpt" % iteration
    )
    best_ckpt_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + "best.ckpt"
    )

    best_iteration = self.trainer.early_stopping.best_monitored_iteration
    best_metric = self.trainer.early_stopping.best_monitored_value

    ckpt = {
        "model": self.trainer.model.state_dict(),
        "optimizer": self.trainer.optimizer.state_dict(),
        "best_iteration": best_iteration,
        "best_metric_value": best_metric,
        "config": self.config,
    }

    git_metadata_dict = self._get_vcs_fields()
    ckpt.update(git_metadata_dict)

    # Serialize to a temp file in the same directory, then atomically
    # rename it. os.replace is atomic when source and destination are on
    # the same filesystem, which a sibling path guarantees.
    ckpt_tmp_filepath = ckpt_filepath + ".tmp"
    torch.save(ckpt, ckpt_tmp_filepath)
    os.replace(ckpt_tmp_filepath, ckpt_filepath)

    if update_best:
        # Copy the bytes already on disk rather than serializing the state
        # dicts a second time; the temp-then-replace step keeps the old
        # best checkpoint intact if this copy is interrupted.
        best_tmp_filepath = best_ckpt_filepath + ".tmp"
        shutil.copyfile(ckpt_filepath, best_tmp_filepath)
        os.replace(best_tmp_filepath, best_ckpt_filepath)
| 235 | def restore(self): |
| 236 | self.trainer.writer.write("Restoring checkpoint") |
no test coverage detected