MCPcopy Index your code
hub / github.com/shenweichen/GraphEmbedding / Classifier

Class Classifier

ge/classify.py:22–66  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

20
21
class Classifier(object):
    """Multi-label classifier evaluated on top of precomputed embeddings.

    ``embeddings`` is indexed by the items of ``X`` (``embeddings[x]``) to
    obtain one feature vector per sample.  The supplied ``clf`` is wrapped in
    ``TopKRanker`` (defined elsewhere in this module) and label sets are
    encoded with a ``MultiLabelBinarizer`` configured for sparse output.
    """

    def __init__(self, embeddings, clf):
        """Store the embedding lookup and wrap ``clf`` for top-k prediction."""
        self.embeddings = embeddings
        self.clf = TopKRanker(clf)
        self.binarizer = MultiLabelBinarizer(sparse_output=True)

    def train(self, X, Y, Y_all):
        """Fit the binarizer on ``Y_all`` and the classifier on ``(X, Y)``.

        Args:
            X: Keys into ``self.embeddings`` for the training samples.
            Y: Label collections, one per entry of ``X``.
            Y_all: Label collections used to fit the binarizer; passing the
                labels of the whole dataset lets labels that are absent from
                ``Y`` still be encoded later.
        """
        self.binarizer.fit(Y_all)
        X_train = [self.embeddings[x] for x in X]
        Y = self.binarizer.transform(Y)
        self.clf.fit(X_train, Y)

    def evaluate(self, X, Y):
        """Predict labels for ``X`` and score them against ``Y``.

        For each sample, as many labels are requested from the classifier as
        the ground truth contains (``len`` of that sample's label collection).

        Returns:
            dict: F1 scores under the keys ``"micro"``, ``"macro"``,
            ``"samples"`` and ``"weighted"``, plus ``"acc"`` from
            ``accuracy_score``.  The dict is also printed.
        """
        top_k_list = [len(labels) for labels in Y]
        Y_ = self.predict(X, top_k_list)
        Y = self.binarizer.transform(Y)
        averages = ["micro", "macro", "samples", "weighted"]
        results = {
            average: f1_score(Y, Y_, average=average) for average in averages
        }
        results['acc'] = accuracy_score(Y, Y_)
        print('-------------------')
        print(results)
        return results

    def predict(self, X, top_k_list):
        """Return the wrapped classifier's prediction for ``X``.

        Args:
            X: Keys into ``self.embeddings``.
            top_k_list: Forwarded to the classifier as ``top_k_list``; in
                ``evaluate`` it holds one label count per sample.
        """
        X_ = numpy.asarray([self.embeddings[x] for x in X])
        return self.clf.predict(X_, top_k_list=top_k_list)

    def split_train_evaluate(self, X, Y, train_precent, seed=0):
        """Shuffle ``(X, Y)``, train on a leading fraction, evaluate the rest.

        The global NumPy RNG is seeded with ``seed`` for the shuffle and for
        training, and its previous state is restored afterwards -- including
        when seeding, splitting or training raises -- so callers' random
        streams are not left perturbed.

        Args:
            X: Sample keys into ``self.embeddings``.
            Y: Label collections aligned with ``X``.
            train_precent: Fraction of the samples used for training; the
                training size is ``int(train_precent * len(X))``.
            seed: Seed for the global NumPy RNG during the split and training.

        Returns:
            dict: The result of ``evaluate`` on the held-out samples.
        """
        state = numpy.random.get_state()
        try:
            training_size = int(train_precent * len(X))
            numpy.random.seed(seed)
            shuffle_indices = numpy.random.permutation(len(X))
            train_indices = shuffle_indices[:training_size]
            test_indices = shuffle_indices[training_size:]

            X_train = [X[i] for i in train_indices]
            Y_train = [Y[i] for i in train_indices]
            X_test = [X[i] for i in test_indices]
            Y_test = [Y[i] for i in test_indices]

            # The binarizer is fitted on every label set so that labels which
            # only occur in the test split can still be encoded.
            self.train(X_train, Y_train, Y)
        finally:
            # Always hand the caller's RNG state back, even on failure.
            numpy.random.set_state(state)
        return self.evaluate(X_test, Y_test)
67
68
69def read_node_label(filename, skip_head=False):

Callers 6

evaluate_embeddingsFunction · 0.90
evaluate_embeddingsFunction · 0.90
evaluate_embeddingsFunction · 0.90
evaluate_embeddingsFunction · 0.90
evaluate_embeddingsFunction · 0.90
evaluate_embeddingsFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected