| 68 | self.test_step_outputs.clear() |
| 69 | |
def preprocess_state_dict(self, state_dict):
    """Merge a checkpoint state dict with this module's live metric/loss state.

    Entries for the metric and loss sub-objects are always taken from the
    current ``self.metrics`` / ``self._losses`` (via their ``state_dict()``),
    re-prefixed with ``'metrics.'`` and ``'_losses.'``. Checkpoint entries in
    those namespaces are discarded so they cannot overwrite the live values.
    Every other checkpoint entry is copied through unchanged.

    Args:
        state_dict: Mapping of key -> value, presumably a checkpoint's
            model state dict (TODO confirm against callers).

    Returns:
        OrderedDict: live metric entries first, then live loss entries, then
        the remaining ``state_dict`` entries in their original order.
    """
    new_state_dict = OrderedDict()

    for k, v in self.metrics.state_dict().items():
        new_state_dict['metrics.' + k] = v
    for k, v in self._losses.state_dict().items():
        new_state_dict['_losses.' + k] = v

    for k, v in state_dict.items():
        # Skip checkpoint keys in the namespaces populated above. The legacy
        # filter only tested for the capitalised substring 'Metrics', which
        # never matches the lowercase 'metrics.' prefix, so checkpoint metric
        # values used to overwrite the live ones inserted above.
        if k.startswith(('metrics.', '_losses.')):
            continue
        # Legacy substring filters, kept so no previously-dropped key returns.
        if '_losses' in k or 'Metrics' in k:
            continue
        new_state_dict[k] = v

    return new_state_dict
| 87 | |
| 88 | def load_state_dict(self, state_dict, strict=True): |
| 89 | new_state_dict = self.preprocess_state_dict(state_dict) |