(self)
| 237 | save_info(str(self.info_path), {"num_classes": self.num_classes}) |
| 238 | |
def load(self):
    """Restore the processed dataset from its on-disk cache.

    Loads the graph list stored at ``self.graph_path`` and the metadata
    dict stored at ``self.info_path``, keeps the first graph as
    ``self._g``, restores ``self._num_classes`` from the metadata's
    ``"num_classes"`` entry, and rebuilds the ``train_mask`` /
    ``val_mask`` / ``test_mask`` node data with ``generate_mask_tensor``.
    When ``self.verbose`` is true, prints a short summary of the graph.

    Raises:
        KeyError: if the metadata lacks ``"num_classes"`` or the cached
            graph lacks one of the split masks (or ``"feat"`` when
            verbose output is requested).
    """
    graphs, _ = load_graphs(str(self.graph_path))
    info = load_info(str(self.info_path))

    self._g = graphs[0]
    self._num_classes = info["num_classes"]

    # Rebuild each split mask by round-tripping it through NumPy and
    # generate_mask_tensor. The original code labels this a hack for MXNet
    # compatibility — presumably so every backend gets the mask dtype
    # generate_mask_tensor produces (TODO confirm).
    for mask_name in ("train_mask", "val_mask", "test_mask"):
        self._g.ndata[mask_name] = generate_mask_tensor(
            F.asnumpy(self._g.ndata[mask_name])
        )

    if self.verbose:
        ndata = self._g.ndata
        print(" NumNodes: {}".format(self._g.num_nodes()))
        print(" NumEdges: {}".format(self._g.num_edges()))
        print(" NumFeats: {}".format(ndata["feat"].shape[1]))
        print(" NumClasses: {}".format(self.num_classes))
        print(
            " NumTrainingSamples: {}".format(
                F.nonzero_1d(ndata["train_mask"]).shape[0]
            )
        )
        print(
            " NumValidationSamples: {}".format(
                F.nonzero_1d(ndata["val_mask"]).shape[0]
            )
        )
        print(
            " NumTestSamples: {}".format(
                F.nonzero_1d(ndata["test_mask"]).shape[0]
            )
        )
| 286 | |
| 287 | def __getitem__(self, idx): |
| 288 | assert idx == 0, "This dataset has only one graph" |
nothing calls this directly
no test coverage detected