MCPcopy Index your code
hub / github.com/shenweichen/GraphEmbedding / batch_iter

Method batch_iter

ge/models/line.py:142–195  ·  view source on GitHub ↗
(self, node2idx)

Source from the content-addressed store, hash-verified

140 self.edge_accept, self.edge_alias = create_alias_table(norm_prob)
141
def batch_iter(self, node2idx):
    """Yield LINE training batches forever, alternating positive and negative.

    Each epoch walks a fresh ``np.random.permutation`` of the graph's edges in
    chunks of ``self.batch_size``. For every chunk, one positive batch is
    yielded, followed by ``self.negative_ratio`` negative batches. The
    negative batches reuse the positive heads and draw new tails with
    ``alias_sample(self.node_accept, self.node_alias)``.

    Positive edges are drawn with the alias method. The shuffled edge index
    is the alias bin. It is kept when ``random.random()`` falls below
    ``self.edge_accept[bin]``; otherwise it is replaced by
    ``self.edge_alias[bin]``.

    The generator never terminates. When the graph has no edges it keeps
    yielding empty positive batches.

    Args:
        node2idx: Mapping from each graph node to its integer index.

    Yields:
        ``((heads, tails), labels)``:
        - ``heads`` and ``tails`` are int32 arrays of node indices.
        - ``labels`` is ``(sign, sign)`` when ``self.order == 'all'``, else
          ``(sign,)``.
        - ``sign`` is a float32 array of ``1.0`` for positive batches and
          ``-1.0`` for negative batches.
    """
    # Edge list in the same order as self.graph.edges(), which is the order
    # the caller presumably used when building edge_accept / edge_alias.
    edges = [(node2idx[edge[0]], node2idx[edge[1]])
             for edge in self.graph.edges()]
    data_size = self.graph.number_of_edges()
    # One positive batch plus `negative_ratio` negative batches per chunk.
    mod_size = 1 + self.negative_ratio

    shuffle_indices = np.random.permutation(np.arange(data_size))
    mod = 0
    h = []
    start_index = 0
    end_index = min(start_index + self.batch_size, data_size)
    while True:
        if mod == 0:
            # Positive batch: alias-sample one edge per shuffled slot.
            h = []
            t = []
            for i in range(start_index, end_index):
                edge_idx = shuffle_indices[i]
                if random.random() >= self.edge_accept[edge_idx]:
                    edge_idx = self.edge_alias[edge_idx]
                cur_h, cur_t = edges[edge_idx]
                h.append(cur_h)
                t.append(cur_t)
            sign = np.ones(len(h), dtype=np.float32)
        else:
            # Negative batch: keep the heads, draw fresh tails from the
            # node alias table.
            sign = np.ones(len(h), dtype=np.float32) * -1
            t = [alias_sample(self.node_accept, self.node_alias) for _ in h]

        heads = np.asarray(h, dtype=np.int32)
        tails = np.asarray(t, dtype=np.int32)
        if self.order == 'all':
            yield ((heads, tails), (sign, sign))
        else:
            yield ((heads, tails), (sign,))

        mod = (mod + 1) % mod_size
        if mod == 0:
            # Positive batch and all its negatives are done: next chunk.
            start_index = end_index
            end_index = min(start_index + self.batch_size, data_size)

        if start_index >= data_size:
            # Epoch exhausted: reshuffle the edges and start over.
            mod = 0
            shuffle_indices = np.random.permutation(np.arange(data_size))
            start_index = 0
            end_index = min(start_index + self.batch_size, data_size)
196
197 def get_embeddings(self, ):
198 self._embeddings = {}

Callers 1

reset_model · Method · 0.95

Calls 1

alias_sample · Function · 0.85

Tested by

no test coverage detected