Used for producing the results of Experiment 3.2 in the BGRL paper.
`evaluate(self)`
| 126 | self._labels = torch.cat([self._labels, y]) |
| 127 | |
def evaluate(self):
    """
    Used for producing the results of Experiment 3.2 in the BGRL paper.

    Linear-probe evaluation of the embeddings held in ``self._embeddings``
    with class labels ``self._labels``. The embeddings are L2-normalised.
    Then, for each of the 20 splits, a one-vs-rest ``LogisticRegression``
    (liblinear solver) is fit on the training rows for every ``C`` in
    ``2**-10 .. 2**10``. The ``C`` with the highest validation accuracy is
    kept, and its test accuracy is recorded.

    Side effects:
      * ``self._dataset`` is moved to the CPU.
      * ``self._train_mask``, ``self._dev_mask`` and ``self._test_mask`` are
        set to the masks of the last split processed.
      * ``self._embeddings`` and ``self._labels`` are NOT overwritten, so the
        method can be called more than once.

    Returns:
        tuple: ``(mean, std)`` of the per-split test accuracies.
    """
    test_accs = []

    # Work on local NumPy copies rather than overwriting the stored tensors:
    # converting them in place made a second call fail on ``.cpu()``.
    # ``detach()`` keeps ``.numpy()`` working if the tensor tracks gradients.
    embeddings = self._embeddings.detach().cpu().numpy()
    labels = self._labels.detach().cpu().numpy().reshape(-1)
    self._dataset.to(torch.device("cpu"))

    # One-hot encode with NumPy. The layout (sorted categories, one column per
    # class) matches ``OneHotEncoder(categories='auto')``. Doing it this way
    # avoids both the scikit-learn ``sparse`` -> ``sparse_output`` rename and
    # the ``np.bool`` alias removed in NumPy 1.24.
    # ``label_idx[k]`` is the column index of sample k's class in ``one_hot``.
    classes, label_idx = np.unique(labels, return_inverse=True)
    one_hot = np.eye(len(classes), dtype=bool)[label_idx]

    embeddings = normalize(embeddings, norm='l2')

    c_grid = 2.0 ** np.arange(-10, 11)

    for i in range(20):
        self._train_mask = self._dataset.train_mask[i]
        self._dev_mask = self._dataset.val_mask[i]
        if self._args.name in ["WikiCS"]:
            # For WikiCS the same test mask is used for every split.
            self._test_mask = self._dataset.test_mask
        else:
            self._test_mask = self._dataset.test_mask[i]

        # Index NumPy arrays with NumPy masks (the dataset masks are
        # presumably CPU torch tensors after the ``.to`` call above).
        train_mask = np.asarray(self._train_mask)
        dev_mask = np.asarray(self._dev_mask)
        test_mask = np.asarray(self._test_mask)

        # Grid search with one-vs-rest classifiers. Start the best validation
        # score below any achievable accuracy so the first C is always
        # scored on the test set.
        best_val_acc, best_test_acc = -1.0, 0.0

        for c in c_grid:
            clf = OneVsRestClassifier(LogisticRegression(solver='liblinear', C=c))
            clf.fit(embeddings[train_mask], one_hot[train_mask])

            # Column j of predict_proba corresponds to column j of the
            # indicator targets, so argmax is the predicted class *index*.
            # Compare it to ``label_idx``, not to raw label values.
            val_pred = np.argmax(clf.predict_proba(embeddings[dev_mask]), axis=1)
            val_acc = metrics.accuracy_score(label_idx[dev_mask], val_pred)
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                test_pred = np.argmax(clf.predict_proba(embeddings[test_mask]), axis=1)
                best_test_acc = metrics.accuracy_score(label_idx[test_mask], test_pred)

        test_accs.append(best_test_acc)

    return np.mean(test_accs), np.std(test_accs)
| 171 | |
| 172 | |
| 173 | def train_eval(args): |
no test coverage detected