(self)
| 88 | ) |
| 89 | |
def process(self):
    """Load the ``self.mode`` split from ``self.save_path`` and build per-graph subgraphs.

    Reads four files from ``self.save_path``: ``{mode}_graph.json``,
    ``{mode}_labels.npy``, ``{mode}_feats.npy`` and ``{mode}_graph_id.npy``.

    Sets the following attributes:

    * ``self._labels`` / ``self._feats``: the arrays loaded from the
      labels / feats ``.npy`` files. They are presumably indexed by node;
      TODO confirm.
    * ``self.graph``: the full node-link graph, wrapped in ``nx.DiGraph``
      and converted with ``from_networkx``.
    * ``self.graphs``: one ``self.graph.subgraph`` per graph id in this
      split's id range. Each carries ``ndata["feat"]`` and
      ``ndata["label"]`` as float32 tensors.
    """
    def _split_path(suffix):
        # Every split file follows the "<mode>_<suffix>" naming scheme.
        return os.path.join(self.save_path, "{}_{}".format(self.mode, suffix))

    # Use a context manager so the JSON file handle is closed promptly
    # rather than left to the garbage collector. JSON text is UTF-8.
    with open(_split_path("graph.json"), encoding="utf-8") as f:
        g_data = json.load(f)
    self._labels = np.load(_split_path("labels.npy"))
    self._feats = np.load(_split_path("feats.npy"))
    # node_link_graph builds a networkx graph from the JSON dict. It is
    # wrapped in nx.DiGraph before conversion.
    self.graph = from_networkx(nx.DiGraph(json_graph.node_link_graph(g_data)))
    graph_id = np.load(_split_path("graph_id.npy"))

    # Half-open graph-id range [lo, hi) for each split: 20 graphs for
    # training, 2 for validation and 2 for testing. Any mode other than
    # "valid" or "test" uses the training range.
    split_ranges = {"valid": (21, 23), "test": (23, 25)}
    lo, hi = split_ranges.get(self.mode, (1, 21))

    float32 = F.data_type_dict["float32"]
    self.graphs = []
    for g_id in range(lo, hi):
        # Positions in graph_id equal to g_id. This assumes graph_id is a
        # 1-D per-node array aligned with the node order of self.graph;
        # TODO confirm.
        g_mask = np.where(graph_id == g_id)[0]
        g = self.graph.subgraph(g_mask)
        g.ndata["feat"] = F.tensor(self._feats[g_mask], dtype=float32)
        g.ndata["label"] = F.tensor(self._labels[g_mask], dtype=float32)
        self.graphs.append(g)
| 133 | |
| 134 | @property |
| 135 | def graph_list_path(self): |
nothing calls this directly
no test coverage detected