The original knowledge base is stored in triplets. This function will parse these triplets and build the DGLGraph.
(self)
| 84 | extract_archive(tgz_path, self.raw_path) |
| 85 | |
| 86 | def process(self): |
| 87 | """ |
| 88 | The original knowledge base is stored in triplets. |
| 89 | This function will parse these triplets and build the DGLGraph. |
| 90 | """ |
| 91 | root_path = self.raw_path |
| 92 | entity_path = os.path.join(root_path, "entities.dict") |
| 93 | relation_path = os.path.join(root_path, "relations.dict") |
| 94 | train_path = os.path.join(root_path, "train.txt") |
| 95 | valid_path = os.path.join(root_path, "valid.txt") |
| 96 | test_path = os.path.join(root_path, "test.txt") |
| 97 | entity_dict = _read_dictionary(entity_path) |
| 98 | relation_dict = _read_dictionary(relation_path) |
| 99 | train = np.asarray( |
| 100 | _read_triplets_as_list(train_path, entity_dict, relation_dict) |
| 101 | ) |
| 102 | valid = np.asarray( |
| 103 | _read_triplets_as_list(valid_path, entity_dict, relation_dict) |
| 104 | ) |
| 105 | test = np.asarray( |
| 106 | _read_triplets_as_list(test_path, entity_dict, relation_dict) |
| 107 | ) |
| 108 | num_nodes = len(entity_dict) |
| 109 | num_rels = len(relation_dict) |
| 110 | if self.verbose: |
| 111 | print("# entities: {}".format(num_nodes)) |
| 112 | print("# relations: {}".format(num_rels)) |
| 113 | print("# training edges: {}".format(train.shape[0])) |
| 114 | print("# validation edges: {}".format(valid.shape[0])) |
| 115 | print("# testing edges: {}".format(test.shape[0])) |
| 116 | |
| 117 | # for compatibility |
| 118 | self._train = train |
| 119 | self._valid = valid |
| 120 | self._test = test |
| 121 | |
| 122 | self._num_nodes = num_nodes |
| 123 | self._num_rels = num_rels |
| 124 | # build graph |
| 125 | g, data = build_knowledge_graph( |
| 126 | num_nodes, num_rels, train, valid, test, reverse=self.reverse |
| 127 | ) |
| 128 | ( |
| 129 | etype, |
| 130 | ntype, |
| 131 | train_edge_mask, |
| 132 | valid_edge_mask, |
| 133 | test_edge_mask, |
| 134 | train_mask, |
| 135 | val_mask, |
| 136 | test_mask, |
| 137 | ) = data |
| 138 | g.edata["train_edge_mask"] = train_edge_mask |
| 139 | g.edata["valid_edge_mask"] = valid_edge_mask |
| 140 | g.edata["test_edge_mask"] = test_edge_mask |
| 141 | g.edata["train_mask"] = train_mask |
| 142 | g.edata["val_mask"] = val_mask |
| 143 | g.edata["test_mask"] = test_mask |
nothing calls this directly
no test coverage detected