MCPcopy
hub / github.com/THUDM/CogDL / evaluate

Method evaluate

examples/bgrl/train.py:128–170  ·  view source on GitHub ↗

Used for producing the results of Experiment 3.2 in the BGRL paper.

(self)

Source from the content-addressed store, hash-verified

126 self._labels = torch.cat([self._labels, y])
127
def evaluate(self, num_splits=20):
    """Linear-probe evaluation used for Experiment 3.2 of the BGRL paper.

    The stored node embeddings are L2-normalized. For each of the first
    ``num_splits`` splits, a one-vs-rest logistic-regression classifier
    (liblinear) is fit for every ``C`` in ``2.0 ** np.arange(-10, 11)``. The
    ``C`` with the highest validation accuracy is selected, and its test
    accuracy is recorded for that split.

    This method does not overwrite ``self._embeddings`` or ``self._labels``,
    so it can safely be called more than once.

    Args:
        num_splits: Number of splits to evaluate. Split ``i`` is read from
            ``self._dataset.train_mask[i]`` and ``self._dataset.val_mask[i]``,
            and from ``self._dataset.test_mask[i]`` except for WikiCS (see
            below). Defaults to 20, the previously hard-coded value.

    Returns:
        Tuple ``(mean, std)`` of the per-split best test accuracies.

    Side effects:
        Calls ``self._dataset.to(torch.device("cpu"))``.
    """

    def _as_numpy(x):
        # Accept torch tensors (any device, possibly requiring grad) or array-likes.
        if torch.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    embeddings = normalize(_as_numpy(self._embeddings), norm="l2")
    labels = _as_numpy(self._labels).reshape(-1, 1)
    # The masks are used to index numpy arrays below, so keep the dataset on CPU.
    self._dataset.to(torch.device("cpu"))

    # `sparse` was renamed to `sparse_output` in scikit-learn 1.2 (removed in 1.4).
    try:
        one_hot_encoder = OneHotEncoder(categories="auto", sparse_output=False)
    except TypeError:  # scikit-learn < 1.2
        one_hot_encoder = OneHotEncoder(categories="auto", sparse=False)
    # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
    labels_onehot = one_hot_encoder.fit_transform(labels).astype(bool)
    # Column k of the one-hot matrix (and of predict_proba below) corresponds to
    # category k of the encoder. Scoring compares class *indices*, so label sets
    # that are not exactly 0..C-1 are still evaluated correctly.
    class_idx = labels_onehot.argmax(axis=1)

    test_accs = []
    for i in range(num_splits):
        train_mask = _as_numpy(self._dataset.train_mask[i])
        dev_mask = _as_numpy(self._dataset.val_mask[i])
        # For WikiCS the whole test_mask is reused for every split;
        # other datasets provide one test mask per split.
        if self._args.name in ["WikiCS"]:
            test_mask = _as_numpy(self._dataset.test_mask)
        else:
            test_mask = _as_numpy(self._dataset.test_mask[i])

        # Grid search over C with one-vs-rest classifiers; select on validation accuracy.
        best_val_acc, best_test_acc = 0, 0
        for c in 2.0 ** np.arange(-10, 11):
            clf = OneVsRestClassifier(LogisticRegression(solver="liblinear", C=c))
            clf.fit(embeddings[train_mask], labels_onehot[train_mask])

            val_pred = np.argmax(clf.predict_proba(embeddings[dev_mask]), axis=1)
            val_acc = metrics.accuracy_score(class_idx[dev_mask], val_pred)
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                test_pred = np.argmax(clf.predict_proba(embeddings[test_mask]), axis=1)
                best_test_acc = metrics.accuracy_score(class_idx[test_mask], test_pred)
        test_accs.append(best_test_acc)

    return np.mean(test_accs), np.std(test_accs)
171
172
173def train_eval(args):

Callers (1)

train (method) · 0.95

Calls (3)

LogisticRegression (class) · 0.70
to (method) · 0.45
device (method) · 0.45

Tested by

no test coverage detected