(self, name: str, **kwargs: Dict[str, str])
| 33 | self.rank = comm.get_rank() |
| 34 | |
def save(self, name: str, **kwargs: Any):
    """Write the partial-FC weights and checkpointables to ``<save_dir>/<name>.pth``.

    Nothing is written when ``self.save_dir`` is empty or ``self.save_to_disk``
    is false.

    The serialized dict contains, in insertion order:
      * ``"model"``: ``{"weight": self.model.weight.data,
        "momentum": self.model.weight_mom}``;
      * one entry per item of ``self.checkpointables``, holding that object's
        ``state_dict()``;
      * every extra keyword argument. These are merged last, so a kwarg whose
        name matches ``"model"`` or a checkpointable key overwrites that entry.

    After the file is written, its basename is passed to
    ``self.tag_last_checkpoint``.

    Args:
        name: File stem (``.pth`` is appended); must not contain a path
            separator.
        **kwargs: Extra values stored alongside the weights; they must be
            serializable by ``torch.save``.

    Raises:
        AssertionError: if ``name`` contains a path component, i.e. the joined
            path's basename differs from ``<name>.pth``.
    """
    if not self.save_dir or not self.save_to_disk:
        return

    data = {
        "model": {
            "weight": self.model.weight.data,
            "momentum": self.model.weight_mom,
        }
    }
    for key, obj in self.checkpointables.items():
        data[key] = obj.state_dict()
    # Extras go in last and may therefore shadow the entries above.
    data.update(kwargs)

    # NOTE(review): ``self.rank`` is set elsewhere in this class but not used
    # here, so every process that reaches this line writes the same
    # ``<name>.pth``. If the partial-FC weight is sharded per rank, shards
    # would overwrite one another. Confirm whether callers already encode the
    # rank in ``name`` before changing the filename scheme.
    basename = f"{name}.pth"
    save_file = os.path.join(self.save_dir, basename)
    # Reject names that would smuggle a directory component into the path.
    assert os.path.basename(save_file) == basename, basename
    self.logger.info("Saving partial fc weights")
    with PathManager.open(save_file, "wb") as f:
        torch.save(data, f)
    self.tag_last_checkpoint(basename)
| 55 | |
| 56 | def _load_model(self, checkpoint: Any): |
| 57 | checkpoint_state_dict = checkpoint.pop("model") |
no test coverage detected