(self, cur_epoch)
| 161 | raise NotImplementedError |
| 162 | |
def write_epoch(self, cur_epoch):
    """Collect the statistics of the epoch that just ended and emit them.

    The stats dict is assembled from ``self.basic()``, the task metrics
    and ``self.custom()``. It is then logged, handed to ``dict_to_json``
    with the path ``<out_dir>/stats.json`` and, when
    ``cfg.tensorboard_each_run`` is set, to ``dict_to_tb``. Finally
    ``self.reset()`` is called.

    Args:
        cur_epoch: Epoch identifier; stored under the ``'epoch'`` key and
            passed to ``self.eta`` and ``dict_to_tb``.

    Raises:
        ValueError: If a name in ``cfg.custom_metrics`` has no truthy
            entry in ``register.metric_dict``, or if no custom metrics
            are configured and ``self.task_type`` is not one of
            ``'regression'``, ``'classification_binary'`` or
            ``'classification_multi'``.
    """
    basic_stats = self.basic()

    # Custom metrics, when configured, take the place of the defaults.
    task_stats = {}
    for metric_name in cfg.custom_metrics:
        metric_fn = register.metric_dict.get(metric_name)
        if not metric_fn:
            raise ValueError(
                f'Unknown custom metric function name: {metric_name}')
        task_stats[metric_name] = metric_fn(
            self._true, self._pred, self.task_type)

    if not task_stats:
        # No custom metric: fall back to the default metrics of the task.
        # Every supported task type is served by the method of the same
        # name on this object.
        for task_name in ('regression',
                          'classification_binary',
                          'classification_multi'):
            if self.task_type == task_name:
                task_stats = getattr(self, task_name)()
                break
        else:
            raise ValueError('Task has to be regression or classification')

    # The ETA is evaluated for every split but only reported for 'train'.
    eta_entry = {'eta': round(self.eta(cur_epoch), cfg.round)}
    custom_stats = self.custom()
    is_train = self.name == 'train'

    # Key order: epoch, (eta), basic, task-specific, custom.
    stats = {
        'epoch': cur_epoch,
        **(eta_entry if is_train else {}),
        **basic_stats,
        **task_stats,
        **custom_stats,
    }

    logging.info(f'{self.name}: {stats}')
    # json
    dict_to_json(stats, f'{self.out_dir}/stats.json')
    # tensorboard
    if cfg.tensorboard_each_run:
        dict_to_tb(stats, self.tb_writer, cur_epoch)
    self.reset()
| 214 | |
| 215 | def close(self): |
| 216 | if cfg.tensorboard_each_run: |
no test coverage detected