Repository: github.com/williamleif/GraphSAGE

Function run_regression

eval_scripts/citation_eval.py, lines 19–29
(train_embeds, train_labels, test_embeds, test_labels)


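```python
import numpy as np  # np is used below; needed for this excerpt to run on its own


def run_regression(train_embeds, train_labels, test_embeds, test_labels):
    """Fit logistic regression on the embeddings; print test micro-F1 vs. a random baseline."""
    # Seeds NumPy's global RNG, which both classifiers use since no random_state is passed.
    np.random.seed(1)
    from sklearn.linear_model import SGDClassifier
    from sklearn.dummy import DummyClassifier
    from sklearn.metrics import f1_score
    # "stratified" draws labels from the training class distribution (a random guesser).
    # It was DummyClassifier's default before scikit-learn 0.24; the default is now "prior".
    dummy = DummyClassifier(strategy="stratified")
    dummy.fit(train_embeds, train_labels)
    # loss="log" was renamed "log_loss" in scikit-learn 1.1 and removed in 1.3.
    log = SGDClassifier(loss="log_loss", n_jobs=10)
    log.fit(train_embeds, train_labels)
    print("F1 score:", f1_score(test_labels, log.predict(test_embeds), average="micro"))
    print("Random baseline f1 score:", f1_score(test_labels, dummy.predict(test_embeds), average="micro"))
```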

Callers (1)

citation_eval.py (file) · 0.70

Calls (1)

predict (method) · 0.45

Tested by

No test coverage detected.