MCPcopy Index your code
hub / github.com/shenweichen/GraphEmbedding / LINE

Class LINE

ge/models/line.py:72–233  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

70
71
72class LINE:
73 def __init__(self, graph, embedding_size=8, negative_ratio=5, order='second', ):
74 """
75
76 :param graph:
77 :param embedding_size:
78 :param negative_ratio:
79 :param order: 'first','second','all'
80 """
81 if order not in ['first', 'second', 'all']:
82 raise ValueError('mode must be fisrt,second,or all')
83
84 self.graph = graph
85 self.idx2node, self.node2idx = preprocess_nxgraph(graph)
86 self.use_alias = True
87
88 self.rep_size = embedding_size
89 self.order = order
90
91 self._embeddings = {}
92 self.negative_ratio = negative_ratio
93 self.order = order
94
95 self.node_size = graph.number_of_nodes()
96 self.edge_size = graph.number_of_edges()
97 self.samples_per_epoch = self.edge_size * (1 + negative_ratio)
98
99 self._gen_sampling_table()
100 self.reset_model()
101
102 def reset_training_config(self, batch_size, times):
103 self.batch_size = batch_size
104 self.steps_per_epoch = (
105 (self.samples_per_epoch - 1) // self.batch_size + 1) * times
106
def reset_model(self, opt='adam'):
    """(Re)create and compile the model, then restart the batch iterator.

    Builds the model and its embedding dict via ``create_model`` from the
    stored node count, ``rep_size`` and ``order``, compiles it with
    ``line_loss``, and replaces ``batch_it`` with a new ``batch_iter``.

    :param opt: optimizer passed to ``compile`` (default ``'adam'``).
    """
    built = create_model(self.node_size, self.rep_size, self.order)
    self.model, self.embedding_dict = built
    self.model.compile(opt, line_loss)

    # Start a fresh iterator over the current node->index mapping.
    self.batch_it = self.batch_iter(self.node2idx)
113
114 def _gen_sampling_table(self):
115
116 # create sampling table for vertex
117 power = 0.75
118 numNodes = self.node_size
119 node_degree = np.zeros(numNodes) # out degree
120 node2idx = self.node2idx
121
122 for edge in self.graph.edges():
123 node_degree[node2idx[edge[0]]
124 ] += self.graph[edge[0]][edge[1]].get('weight', 1.0)
125
126 total_sum = sum([math.pow(node_degree[i], power)
127 for i in range(numNodes)])
128 norm_prob = [float(math.pow(node_degree[j], power)) /
129 total_sum for j in range(numNodes)]

Callers 2

test_LINE · Function · 0.90
main · Function · 0.90

Calls

no outgoing calls

Tested by 1

test_LINE · Function · 0.72