(self)
| 94 | |
| 95 | # task properties |
def classification_binary(self):
    """Summarize binary-classification performance over the logged batches.

    Concatenates the stored ground-truth labels (``self._true``) and raw
    prediction scores (``self._pred``), converts the scores to hard labels
    via ``self._get_pred_int``, and scores them with scikit-learn.

    Returns:
        dict: ``accuracy``, ``precision``, ``recall``, ``f1`` and ``auc``
        (in that key order), each rounded to ``cfg.round`` digits. ``auc``
        is computed from the raw scores and falls back to 0.0 whenever
        ``roc_auc_score`` raises ``ValueError``.
    """
    from sklearn.metrics import (accuracy_score, f1_score, precision_score,
                                 recall_score, roc_auc_score)

    labels = torch.cat(self._true)
    scores = torch.cat(self._pred)
    hard_preds = self._get_pred_int(scores)

    # AUC needs the continuous scores, not the thresholded labels.
    try:
        auc = roc_auc_score(labels, scores)
    except ValueError:
        # NOTE(review): presumably hit when only one class is present in
        # `labels`; report 0.0 instead of failing the whole summary.
        auc = 0.0

    # Metrics computed on the hard (thresholded) predictions, in output order.
    hard_label_metrics = (
        ('accuracy', accuracy_score),
        ('precision', precision_score),
        ('recall', recall_score),
        ('f1', f1_score),
    )
    summary = {
        name: round(metric(labels, hard_preds), cfg.round)
        for name, metric in hard_label_metrics
    }
    summary['auc'] = round(auc, cfg.round)
    return summary
| 113 | |
| 114 | def classification_multi(self): |
| 115 | from sklearn.metrics import accuracy_score |
no test coverage detected