MCPcopy
hub / github.com/snap-stanford/GraphGym / classification_binary

Method classification_binary

graphgym/logger.py:96–112  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

94
95 # task properties
def classification_binary(self):
    """Compute binary-classification metrics over the accumulated batches.

    Concatenates the tensors stored in ``self._true`` and ``self._pred``,
    converts the scores to integer predictions with
    ``self._get_pred_int`` and evaluates the scikit-learn metrics on them.

    Returns:
        dict: keys ``'accuracy'``, ``'precision'``, ``'recall'``, ``'f1'``
        and ``'auc'``, each rounded to ``cfg.round`` digits. ``'auc'`` is
        reported as ``0.0`` when ``roc_auc_score`` raises ``ValueError``.
    """
    from sklearn.metrics import (accuracy_score, f1_score, precision_score,
                                 recall_score, roc_auc_score)

    labels = torch.cat(self._true)
    scores = torch.cat(self._pred)
    hard_preds = self._get_pred_int(scores)

    # AUC is computed on the raw scores, not the integer predictions.
    # NOTE(review): the ValueError is presumably the "only one class present
    # in y_true" case -- confirm against the installed scikit-learn version.
    try:
        auc = roc_auc_score(labels, scores)
    except ValueError:
        auc = 0.0

    # Metrics evaluated on the integer predictions, in reporting order.
    thresholded_metrics = (
        ('accuracy', accuracy_score),
        ('precision', precision_score),
        ('recall', recall_score),
        ('f1', f1_score),
    )
    stats = {
        key: round(metric(labels, hard_preds), cfg.round)
        for key, metric in thresholded_metrics
    }
    stats['auc'] = round(auc, cfg.round)
    return stats
113
114 def classification_multi(self):
115 from sklearn.metrics import accuracy_score

Callers 1

write_epoch (Method) · 0.95

Calls 1

_get_pred_int (Method) · 0.95

Tested by

no test coverage detected