github.com/williamleif/GraphSAGE

Function run_regression

eval_scripts/reddit_eval.py:8–22
(train_embeds, train_labels, test_embeds, test_labels)


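```python
from argparse import ArgumentParser

import numpy as np


def run_regression(train_embeds, train_labels, test_embeds, test_labels):
    # Both estimators fall back to NumPy's global RNG when random_state is unset,
    # so seeding it makes the SGD shuffling and the random baseline reproducible.
    np.random.seed(1)
    from sklearn.linear_model import SGDClassifier
    from sklearn.dummy import DummyClassifier
    from sklearn.metrics import f1_score

    # "Random baseline": predictions sampled from the training label distribution.
    # "stratified" was DummyClassifier's default before scikit-learn 0.24 (now "prior").
    dummy = DummyClassifier(strategy="stratified")
    dummy.fit(train_embeds, train_labels)

    # Logistic regression trained with SGD; multiclass is handled one-vs-all and
    # n_jobs sets the number of CPUs used for the per-class fits.
    # loss="log" was renamed "log_loss" in scikit-learn 1.1 and removed in 1.3.
    log = SGDClassifier(loss="log_loss", n_jobs=55)
    log.fit(train_embeds, train_labels)

    print("Test scores")
    print(f1_score(test_labels, log.predict(test_embeds), average="micro"))
    print("Train scores")
    print(f1_score(train_labels, log.predict(train_embeds), average="micro"))
    print("Random baseline")
    print(f1_score(test_labels, dummy.predict(test_embeds), average="micro"))


if __name__ == '__main__':
    parser = ArgumentParser("Run evaluation on Reddit data.")
```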
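All three printed numbers are micro-averaged F1 scores. Because `SGDClassifier` produces exactly one label per sample, micro-F1 reduces to accuracy: each wrong prediction counts once as a false positive (for the predicted class) and once as a false negative (for the true class), so pooled precision and recall both equal accuracy. If the baseline samples labels from the training distribution (`DummyClassifier(strategy="stratified")`, the default before scikit-learn 0.24; newer versions default to `"prior"`, which always predicts the majority class), its expected score is the sum over classes of p_test(k) · p_train(k). The pure-Python helpers here (`accuracy`, `micro_f1`, `expected_stratified_accuracy`) are hypothetical illustrations of those relationships, not scikit-learn's `f1_score`.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of samples whose predicted label equals the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def micro_f1(y_true, y_pred):
    """Micro-averaged F1 (hypothetical helper): pool TP, FP and FN over all classes, then compute F1 once."""
    classes = set(y_true) | set(y_pred)
    pairs = list(zip(y_true, y_pred))
    tp = sum(sum(t == c and p == c for t, p in pairs) for c in classes)
    fp = sum(sum(t != c and p == c for t, p in pairs) for c in classes)
    fn = sum(sum(t == c and p != c for t, p in pairs) for c in classes)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def expected_stratified_accuracy(train_labels, test_labels):
    """Expected accuracy when every test prediction is drawn independently from
    the training label distribution (assumes the "stratified" strategy)."""
    train_counts = Counter(train_labels)
    test_counts = Counter(test_labels)
    return sum(
        (test_counts[c] / len(test_labels)) * (train_counts[c] / len(train_labels))
        for c in test_counts
    )


# Single-label predictions: every error is one FP and one FN,
# so pooled precision == recall == accuracy, and micro-F1 == accuracy.
y_true = ["a", "a", "b", "c"]
y_pred = ["a", "b", "b", "a"]
print(accuracy(y_true, y_pred), micro_f1(y_true, y_pred))  # 0.5 0.5

# 75% / 25% class split in both train and test: 0.75**2 + 0.25**2
print(expected_stratified_accuracy(["a", "a", "a", "b"], ["a", "a", "a", "b"]))  # 0.625
```

A stratified baseline on a skewed label set therefore scores well above 1/K for K classes, which is why the script prints it next to the classifier's test score instead of relying on chance-level intuition.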

Callers (1)

reddit_eval.py · File · 0.70

Calls (1)

predict · Method · 0.45
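The index records `predict` as the only call. For an `SGDClassifier` with log loss, as configured in `run_regression`, scikit-learn fits one binary logistic-regression model per class (one-vs-all) and `predict` returns the class with the highest decision score. The pure-Python sketch below illustrates that scheme under simplifying assumptions: a constant learning rate, no L2 penalty, and a fixed number of epochs, rather than scikit-learn's defaults (`learning_rate="optimal"`, `penalty="l2"`, tolerance-based stopping). The helper names `fit_binary_logistic_sgd`, `fit_one_vs_all` and `predict_one_vs_all` are hypothetical, not scikit-learn API.

```python
import math
import random


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def score(w, b, x):
    """Linear decision score w.x + b."""
    return sum(wj * xj for wj, xj in zip(w, x)) + b


def fit_binary_logistic_sgd(X, y, epochs=50, lr=0.1, seed=1):
    """Minimize log loss for labels y in {0, 1} with plain SGD.
    Simplified sketch (assumption): constant step size, no regularization, fixed epochs."""
    rng = random.Random(seed)
    w, b = [0.0] * len(X[0]), 0.0
    order = list(range(len(X)))
    for _ in range(epochs):
        rng.shuffle(order)  # visit samples in a new order each epoch
        for i in order:
            # Gradient of the log loss w.r.t. the score is sigmoid(score) - y
            g = sigmoid(score(w, b, X[i])) - y[i]
            w = [wj - lr * g * xj for wj, xj in zip(w, X[i])]
            b -= lr * g
    return w, b


def fit_one_vs_all(X, labels, **kwargs):
    """One binary model per class: that class (1) vs. every other class (0)."""
    return {
        c: fit_binary_logistic_sgd(X, [1 if l == c else 0 for l in labels], **kwargs)
        for c in sorted(set(labels))
    }


def predict_one_vs_all(models, X):
    """For each row, return the class whose binary model gives the highest score."""
    return [max(models, key=lambda c: score(*models[c], x)) for x in X]


# Toy data: three well-separated 2-D clusters standing in for node embeddings.
X = [[3.0, 0.0], [3.5, 0.5], [2.5, -0.5],
     [0.0, 3.0], [0.5, 3.5], [-0.5, 2.5],
     [-3.0, -3.0], [-3.5, -2.5], [-2.5, -3.5]]
labels = ["a"] * 3 + ["b"] * 3 + ["c"] * 3
models = fit_one_vs_all(X, labels)
print(predict_one_vs_all(models, [[4.0, 0.0], [0.0, 4.0], [-4.0, -4.0]]))
```

The one-vs-all design trains each class's model independently of the others, which is what makes those per-class fits parallelizable (the role of `n_jobs` in `run_regression`).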

Tested by

no test coverage detected