
Method process

python/dgl/data/knowledge_graph.py:86–146

The original knowledge base is stored as (head, relation, tail) triplets. This method parses the triplets of the train, validation, and test splits and builds a single DGLGraph from them.
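The first half of `process` maps entity and relation names to integer ids and converts each split file into rows of ids. The sketch below illustrates that step in plain Python. It is not DGL's `_read_dictionary` / `_read_triplets_as_list`: the helper names `read_dictionary` and `read_triplets` are hypothetical, and the file layouts (tab-separated `<id> <name>` lines in the `.dict` files, tab-separated `<head> <relation> <tail>` lines in the split files) are assumptions made for illustration.

```python
import io
from typing import Dict, Iterable, List


def read_dictionary(lines: Iterable[str]) -> Dict[str, int]:
    """Hypothetical helper: map each name to its integer id.

    Assumes every non-blank line is tab-separated as: id, name.
    """
    mapping: Dict[str, int] = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        idx, name = line.split("\t")
        mapping[name] = int(idx)
    return mapping


def read_triplets(
    lines: Iterable[str],
    entity_dict: Dict[str, int],
    relation_dict: Dict[str, int],
) -> List[List[int]]:
    """Hypothetical helper: turn name triplets into [head_id, rel_id, tail_id] rows.

    Assumes every non-blank line is tab-separated as: head, relation, tail.
    """
    triplets: List[List[int]] = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        head, rel, tail = line.split("\t")
        triplets.append([entity_dict[head], relation_dict[rel], entity_dict[tail]])
    return triplets


# In-memory stand-ins for entities.dict, relations.dict and train.txt.
entities = read_dictionary(io.StringIO("0\talice\n1\tbob\n2\tcarol\n"))
relations = read_dictionary(io.StringIO("0\tknows\n1\tlikes\n"))
train = read_triplets(
    io.StringIO("alice\tknows\tbob\nbob\tlikes\tcarol\n"), entities, relations
)
print(train)  # [[0, 0, 1], [1, 1, 2]]
```

As in `process`, the number of nodes and relation types then follows directly from the dictionary sizes (`len(entities)` and `len(relations)` here).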

Signature: process(self)

```python
def process(self):
    """
    The original knowledge base is stored in triplets.
    This function will parse these triplets and build the DGLGraph.
    """
    root_path = self.raw_path
    entity_path = os.path.join(root_path, "entities.dict")
    relation_path = os.path.join(root_path, "relations.dict")
    train_path = os.path.join(root_path, "train.txt")
    valid_path = os.path.join(root_path, "valid.txt")
    test_path = os.path.join(root_path, "test.txt")
    entity_dict = _read_dictionary(entity_path)
    relation_dict = _read_dictionary(relation_path)
    train = np.asarray(
        _read_triplets_as_list(train_path, entity_dict, relation_dict)
    )
    valid = np.asarray(
        _read_triplets_as_list(valid_path, entity_dict, relation_dict)
    )
    test = np.asarray(
        _read_triplets_as_list(test_path, entity_dict, relation_dict)
    )
    num_nodes = len(entity_dict)
    num_rels = len(relation_dict)
    if self.verbose:
        print("# entities: {}".format(num_nodes))
        print("# relations: {}".format(num_rels))
        print("# training edges: {}".format(train.shape[0]))
        print("# validation edges: {}".format(valid.shape[0]))
        print("# testing edges: {}".format(test.shape[0]))

    # for compatibility
    self._train = train
    self._valid = valid
    self._test = test

    self._num_nodes = num_nodes
    self._num_rels = num_rels
    # build graph
    g, data = build_knowledge_graph(
        num_nodes, num_rels, train, valid, test, reverse=self.reverse
    )
    (
        etype,
        ntype,
        train_edge_mask,
        valid_edge_mask,
        test_edge_mask,
        train_mask,
        val_mask,
        test_mask,
    ) = data
    g.edata["train_edge_mask"] = train_edge_mask
    g.edata["valid_edge_mask"] = valid_edge_mask
    g.edata["test_edge_mask"] = test_edge_mask
    g.edata["train_mask"] = train_mask
    g.edata["val_mask"] = val_mask
    g.edata["test_mask"] = test_mask
    # ... (excerpt ends here; the method continues through line 146)
```
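The body of `build_knowledge_graph` is not part of this listing. The sketch below shows one plausible way to flatten the train, validation, and test triplets into a single edge list with one boolean mask per split; it is an illustration under stated assumptions, not DGL's code. The function name `build_edge_list` is hypothetical, as is the convention that the reverse of an edge with relation `r` gets type `r + num_rels`. It also does not attempt to reproduce the difference between the `*_edge_mask` and `*_mask` arrays that `process` stores.

```python
from typing import Dict, List, Sequence, Tuple

Triplet = Sequence[int]  # (head_id, relation_id, tail_id)


def build_edge_list(
    num_rels: int,
    train: List[Triplet],
    valid: List[Triplet],
    test: List[Triplet],
    reverse: bool = True,
) -> Tuple[List[int], List[int], List[int], Dict[str, List[bool]]]:
    """Hypothetical sketch: flatten the splits into (src, dst, etype) plus split masks.

    Edges are emitted train first, then valid, then test. With reverse=True each
    triplet (h, r, t) is followed by an assumed reverse edge (t, r + num_rels, h).
    """
    src: List[int] = []
    dst: List[int] = []
    etype: List[int] = []
    split_of_edge: List[str] = []  # records which split produced each edge

    for split, triplets in (("train", train), ("valid", valid), ("test", test)):
        for h, r, t in triplets:
            src.append(h)
            dst.append(t)
            etype.append(r)
            split_of_edge.append(split)
            if reverse:
                # Assumed convention: reverse relation ids are offset by num_rels.
                src.append(t)
                dst.append(h)
                etype.append(r + num_rels)
                split_of_edge.append(split)

    # One boolean mask per split, aligned with the edge order above.
    masks = {
        split: [s == split for s in split_of_edge]
        for split in ("train", "valid", "test")
    }
    return src, dst, etype, masks


src, dst, etype, masks = build_edge_list(
    num_rels=2,
    train=[(0, 0, 1), (1, 1, 2)],
    valid=[(2, 0, 0)],
    test=[],
)
print(src)             # [0, 1, 1, 2, 2, 0]
print(etype)           # [0, 2, 1, 3, 0, 2]
print(masks["train"])  # [True, True, True, True, False, False]
```

With DGL installed, lists like these could be passed to `dgl.graph((src, dst), num_nodes=num_nodes)` and, after conversion to tensors, stored in `g.edata` in the same way `process` stores its masks.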

Callers

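No direct callers detected. As a DGLDataset hook, process() is normally invoked by the dataset base class during construction rather than by user code.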

Calls (5)

_read_dictionary · function · 0.85
_read_triplets_as_list · function · 0.85
build_knowledge_graph · function · 0.85
format · method · 0.80
join · method · 0.45

Tested by

No test coverage detected.