MCPcopy Index your code
hub / github.com/dmlc/dgl / load

Method load

python/dgl/data/knowledge_graph.py:180–235  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

178 )
179
def load(self):
    """Load the processed graph and its metadata from the on-disk cache.

    Reads the graph stored at ``self.graph_path`` and the metadata dict
    stored at ``self.info_path``. It then:

    * sets ``self._num_nodes`` / ``self._num_rels`` from the metadata keys
      ``"num_nodes"`` / ``"num_rels"``;
    * replaces the six edge-mask features (``train_edge_mask``,
      ``valid_edge_mask``, ``test_edge_mask``, ``train_mask``,
      ``val_mask``, ``test_mask``) with the output of
      ``generate_mask_tensor``, which converts them to bool tensors where
      possible;
    * rebuilds ``self._train``, ``self._valid`` and ``self._test`` as
      arrays with one ``(src, etype, dst)`` row per edge in that split,
      kept for compatibility with DGL 0.4.x.

    If ``self.verbose`` is true, it prints summary counts.
    """
    graphs, _ = load_graphs(str(self.graph_path))

    info = load_info(str(self.info_path))
    self._num_nodes = info["num_nodes"]
    self._num_rels = info["num_rels"]
    self._g = graphs[0]

    # Convert every mask feature into a bool tensor if possible. Keep the
    # NumPy copies: the three "*_edge_mask" arrays also drive the triplet
    # extraction below, so each mask only needs to leave the backend once.
    mask_arrays = {}
    for key in (
        "train_edge_mask",
        "valid_edge_mask",
        "test_edge_mask",
        "train_mask",
        "val_mask",
        "test_mask",
    ):
        mask_arrays[key] = self._g.edata[key].numpy()
        self._g.edata[key] = generate_mask_tensor(mask_arrays[key])

    # For compatibility with DGL 0.4.x, also expose each split as a stacked
    # array of (src, etype, dst) rows.
    etype = self._g.edata["etype"].numpy()
    self._etype = etype
    u, v = self._g.all_edges(form="uv")
    u = u.numpy()
    v = v.numpy()

    def _triplets(mask):
        # Select the edges whose mask entry equals 1 and stack their
        # source, relation type and destination column-wise.
        idx = np.nonzero(mask == 1)
        return np.column_stack((u[idx], etype[idx], v[idx]))

    self._train = _triplets(mask_arrays["train_edge_mask"])
    self._valid = _triplets(mask_arrays["valid_edge_mask"])
    self._test = _triplets(mask_arrays["test_edge_mask"])

    if self.verbose:
        print("# entities: {}".format(self.num_nodes))
        print("# relations: {}".format(self.num_rels))
        print("# training edges: {}".format(self._train.shape[0]))
        print("# validation edges: {}".format(self._valid.shape[0]))
        print("# testing edges: {}".format(self._test.shape[0]))
236
237 @property

Callers

nothing calls this directly

Calls 6

load_graphsFunction · 0.85
load_infoFunction · 0.85
generate_mask_tensorFunction · 0.85
all_edgesMethod · 0.80
nonzeroMethod · 0.80
formatMethod · 0.80

Tested by

no test coverage detected