| 17 | return classes |
| 18 | |
def run_regression(train_embeds, train_labels, test_embeds, test_labels):
    """Fit a logistic-loss SGD classifier and print its micro-F1 on the test set.

    A ``DummyClassifier`` is fitted on the same training data and its
    micro-averaged F1 score is printed alongside as a baseline.

    Args:
        train_embeds: Training features passed to ``fit``
            (presumably a 2-D array of node embeddings -- confirm with caller).
        train_labels: Training targets passed to ``fit``.
        test_embeds: Features passed to ``predict``.
        test_labels: Ground-truth targets used for scoring.

    Returns:
        None. The two scores are written to stdout.
    """
    import re

    import sklearn
    from sklearn.dummy import DummyClassifier
    from sklearn.linear_model import SGDClassifier
    from sklearn.metrics import f1_score

    # Seed NumPy's global RNG so that estimators created without an explicit
    # random_state behave the same way on every run.
    np.random.seed(1)

    # scikit-learn 1.1 renamed the logistic loss from "log" to "log_loss" and
    # 1.3 removed the old spelling; pick whichever the installed version accepts.
    version_match = re.match(r"(\d+)\.(\d+)", sklearn.__version__)
    has_new_loss_name = (
        version_match is None
        or tuple(int(part) for part in version_match.groups()) >= (1, 1)
    )
    logistic_loss = "log_loss" if has_new_loss_name else "log"

    dummy = DummyClassifier()
    dummy.fit(train_embeds, train_labels)

    log = SGDClassifier(loss=logistic_loss, n_jobs=10)
    log.fit(train_embeds, train_labels)

    print("F1 score:", f1_score(test_labels, log.predict(test_embeds), average="micro"))
    print("Random baseline f1 score:", f1_score(test_labels, dummy.predict(test_embeds), average="micro"))
| 30 | |
| 31 | if __name__ == '__main__': |
| 32 | parser = ArgumentParser("Run evaluation on citation data.") |