| 30 | self.logger = logger |
| 31 | |
def save(self, name, **kwargs):
    """Serialize the current training state to ``<save_dir>/<name>.pth``.

    Nothing is written (and ``None`` is returned) when ``self.save_dir`` is
    falsy or when ``self.save_to_disk`` is falsy.

    The saved dict holds:

    - ``"model"``: ``self.model.state_dict()``.
    - ``"optimizer"`` and ``"scheduler"``: their state dicts, included only
      when those attributes are not ``None``.

    Every extra keyword argument is merged in last, so a keyword with the
    same name as one of those keys replaces it.

    The data is first written to a ``<name>.pth.tmp`` sibling and then
    moved into place with ``os.replace``. As a result, a failed or
    interrupted ``torch.save`` never leaves a truncated ``<name>.pth``
    behind, and never clobbers an existing checkpoint of the same name.
    The final path is passed to ``self.tag_last_checkpoint`` only after the
    move succeeds.

    Args:
        name: File stem of the checkpoint; ``.pth`` is appended.
        **kwargs: Extra picklable entries stored alongside the state dicts.

    Raises:
        Whatever ``torch.save`` or ``os.replace`` raises. The temporary
        file is removed before the error propagates.
    """
    if not self.save_dir or not self.save_to_disk:
        return

    data = {"model": self.model.state_dict()}
    if self.optimizer is not None:
        data["optimizer"] = self.optimizer.state_dict()
    if self.scheduler is not None:
        data["scheduler"] = self.scheduler.state_dict()
    # Applied last on purpose: caller-supplied entries win on key clashes.
    data.update(kwargs)

    save_file = os.path.join(self.save_dir, "{}.pth".format(name))
    self.logger.info("Saving checkpoint to {}".format(save_file))

    # Write-then-rename. The temp file lives in the same directory, so
    # os.replace is a same-filesystem rename (atomic on POSIX) and readers
    # only ever see a complete checkpoint file.
    tmp_file = save_file + ".tmp"
    try:
        torch.save(data, tmp_file)
        os.replace(tmp_file, save_file)
    except BaseException:
        # BaseException so KeyboardInterrupt also cleans up; the original
        # error is re-raised unchanged.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise
    self.tag_last_checkpoint(save_file)
| 51 | |
| 52 | def load(self, f=None): |
| 53 | if self.has_checkpoint(): |