| 390 | ) |
| 391 | |
def load(self):
    """Restore the cached graph, held-out edge splits and split ratios.

    Reads two files from ``self.save_path``, both keyed by ``self.hash``:

    * ``graph_<hash>.bin`` via ``utils.load_graphs``. The first stored graph
      becomes both ``self.g`` and ``self._train_graph``. The accompanying
      tensor dict supplies the validation and test edges.
    * ``info_<hash>.json``, which supplies ``split_ratio`` and ``neg_ratio``.

    ``self._val_edges`` and ``self._test_edges`` are each set to
    ``((pos_src, pos_dst), (neg_src, neg_dst))``.

    Raises:
        KeyError: if an expected entry is missing from the tensor dict or
            from the JSON info file.
    """
    graph_path = os.path.join(self.save_path, "graph_{}.bin".format(self.hash))
    graph_list, tensors = utils.load_graphs(graph_path)

    # A single graph object serves as both the full graph and the training graph.
    self.g = self._train_graph = graph_list[0]

    val_pos = (tensors["val_pos_src"], tensors["val_pos_dst"])
    val_neg = (tensors["val_neg_src"], tensors["val_neg_dst"])
    self._val_edges = (val_pos, val_neg)

    test_pos = (tensors["test_pos_src"], tensors["test_pos_dst"])
    test_neg = (tensors["test_neg_src"], tensors["test_neg_dst"])
    self._test_edges = (test_pos, test_neg)

    info_path = os.path.join(self.save_path, "info_{}.json".format(self.hash))
    with open(info_path, "r") as info_file:
        meta = json.load(info_file)
    self.split_ratio = meta["split_ratio"]
    self.neg_ratio = meta["neg_ratio"]
| 413 | |
| 414 | def save(self): |
| 415 | tensor_dict = { |