| 70 | |
| 71 | |
| 72 | class LINE: |
| 73 | def __init__(self, graph, embedding_size=8, negative_ratio=5, order='second', ): |
| 74 | """ |
| 75 | |
| 76 | :param graph: |
| 77 | :param embedding_size: |
| 78 | :param negative_ratio: |
| 79 | :param order: 'first','second','all' |
| 80 | """ |
| 81 | if order not in ['first', 'second', 'all']: |
| 82 | raise ValueError('mode must be fisrt,second,or all') |
| 83 | |
| 84 | self.graph = graph |
| 85 | self.idx2node, self.node2idx = preprocess_nxgraph(graph) |
| 86 | self.use_alias = True |
| 87 | |
| 88 | self.rep_size = embedding_size |
| 89 | self.order = order |
| 90 | |
| 91 | self._embeddings = {} |
| 92 | self.negative_ratio = negative_ratio |
| 93 | self.order = order |
| 94 | |
| 95 | self.node_size = graph.number_of_nodes() |
| 96 | self.edge_size = graph.number_of_edges() |
| 97 | self.samples_per_epoch = self.edge_size * (1 + negative_ratio) |
| 98 | |
| 99 | self._gen_sampling_table() |
| 100 | self.reset_model() |
| 101 | |
| 102 | def reset_training_config(self, batch_size, times): |
| 103 | self.batch_size = batch_size |
| 104 | self.steps_per_epoch = ( |
| 105 | (self.samples_per_epoch - 1) // self.batch_size + 1) * times |
| 106 | |
def reset_model(self, opt='adam'):
    """Rebuild the model, compile it, and restart the batch iterator.

    :param opt: optimizer passed to ``model.compile`` together with
        ``line_loss``.
    """
    built_model, embeddings = create_model(
        self.node_size, self.rep_size, self.order)
    self.model = built_model
    self.embedding_dict = embeddings

    self.model.compile(opt, line_loss)

    # Fresh batch iterator over the node-to-index mapping.
    self.batch_it = self.batch_iter(self.node2idx)
| 113 | |
| 114 | def _gen_sampling_table(self): |
| 115 | |
| 116 | # create sampling table for vertex |
| 117 | power = 0.75 |
| 118 | numNodes = self.node_size |
| 119 | node_degree = np.zeros(numNodes) # out degree |
| 120 | node2idx = self.node2idx |
| 121 | |
| 122 | for edge in self.graph.edges(): |
| 123 | node_degree[node2idx[edge[0]] |
| 124 | ] += self.graph[edge[0]][edge[1]].get('weight', 1.0) |
| 125 | |
| 126 | total_sum = sum([math.pow(node_degree[i], power) |
| 127 | for i in range(numNodes)]) |
| 128 | norm_prob = [float(math.pow(node_degree[j], power)) / |
| 129 | total_sum for j in range(numNodes)] |