| 6 | from argparse import ArgumentParser |
| 7 | |
def run_regression(train_embeds, train_labels, test_embeds, test_labels, n_jobs=55):
    """Fit a logistic-loss SGD classifier on embeddings and print micro-F1 scores.

    Prints three micro-averaged F1 scores to stdout, each preceded by a label:
    the SGD classifier on the test split ("Test scores"), the same classifier
    on the training split ("Train scores"), and a stratified random
    ``DummyClassifier`` on the test split ("Random baseline").

    Args:
        train_embeds: Training feature matrix (presumably (n_train, dim) --
            TODO confirm against callers).
        train_labels: Training targets, in any format ``SGDClassifier.fit``
            accepts.
        test_embeds: Test feature matrix with the same columns as
            ``train_embeds``.
        test_labels: Test targets in the same format as ``train_labels``.
        n_jobs: Number of workers passed to ``SGDClassifier``. Defaults to the
            previously hard-coded 55. Pass ``-1`` to use all available cores.

    Returns:
        None. Scores are only printed.
    """
    # Seed NumPy's global RNG. Neither estimator below sets random_state, so
    # both draw from this RNG, which keeps the printed scores reproducible.
    np.random.seed(1)
    from sklearn.linear_model import SGDClassifier
    from sklearn.dummy import DummyClassifier
    from sklearn.metrics import f1_score

    # scikit-learn 1.1 renamed loss="log" to "log_loss" and 1.3 removed "log".
    # Pick whichever spelling the installed version supports.
    supported_losses = getattr(SGDClassifier, "loss_functions", {})
    loss = "log_loss" if "log_loss" in supported_losses else "log"

    # Set "stratified" explicitly. In scikit-learn 0.24 the default strategy
    # became "prior" (always predict the majority class), which is not the
    # random baseline that the printed label below promises.
    dummy = DummyClassifier(strategy="stratified")
    dummy.fit(train_embeds, train_labels)
    log = SGDClassifier(loss=loss, n_jobs=n_jobs)
    log.fit(train_embeds, train_labels)
    print("Test scores")
    print(f1_score(test_labels, log.predict(test_embeds), average="micro"))
    print("Train scores")
    print(f1_score(train_labels, log.predict(train_embeds), average="micro"))
    print("Random baseline")
    print(f1_score(test_labels, dummy.predict(test_embeds), average="micro"))
| 23 | |
| 24 | if __name__ == '__main__': |
| 25 | parser = ArgumentParser("Run evaluation on Reddit data.") |