Load and process the data.
(self)
| 68 | ) |
| 69 | |
def process(self):
    """Load and process the data.

    Reads the following files located under ``self.raw_path``:

    * ``out1_node_feature_label.txt`` -- one header line, then one row per
      node with three tab-separated fields: node id, comma-separated
      indices of the active features, and the integer label.
    * ``out1_graph_edges.txt`` -- one header line, then one tab-separated
      pair of integer node ids per edge.
    * ``{self.name}_split_0.6_0.2_{i}.npz`` for ``i`` in ``0..9`` -- each
      providing ``train_mask``, ``val_mask`` and ``test_mask`` arrays.

    On success ``self._g`` holds the graph with node data ``feat``,
    ``label``, ``train_mask``, ``val_mask`` and ``test_mask`` (the ten
    splits are stacked along dimension 1), and ``self._num_classes`` is
    set to the largest label plus one.

    Raises
    ------
    ModuleNotFoundError
        If PyTorch cannot be imported.
    """
    try:
        import torch
    except ImportError as err:
        raise ModuleNotFoundError(
            "This dataset requires PyTorch to be the backend."
        ) from err

    # Process node features and labels.
    with open(f"{self.raw_path}/out1_node_feature_label.txt", "r") as f:
        next(f, None)  # Skip the header line.
        # Iterating the file keeps the final record even when the file
        # does not end with a newline; empty lines are ignored.
        node_rows = [
            line.rstrip("\n").split("\t")
            for line in f
            if line.rstrip("\n")
        ]

    rows, cols = [], []
    # NOTE(review): ``torch.empty`` leaves entries uninitialized, so this
    # assumes every id in ``range(len(node_rows))`` appears in the file --
    # confirm against the raw data format.
    labels = torch.empty(len(node_rows), dtype=torch.long)
    for n_id, feat_ids, label in node_rows:
        node = int(n_id)
        active = [int(x) for x in feat_ids.split(",")]
        rows.extend([node] * len(active))
        cols.extend(active)
        labels[node] = int(label)

    # Build a dense 0/1 feature matrix from the (node, feature) pairs.
    row, col = torch.tensor(rows), torch.tensor(cols)
    features = torch.zeros(len(node_rows), int(col.max()) + 1)
    features[row, col] = 1.0

    self._num_classes = int(labels.max().item()) + 1

    # Process graph structure.
    with open(f"{self.raw_path}/out1_graph_edges.txt", "r") as f:
        next(f, None)  # Skip the header line.
        edges = [
            [int(v) for v in line.rstrip("\n").split("\t")]
            for line in f
            if line.rstrip("\n")
        ]
    # The first file column is unpacked as ``dst`` and the second as ``src``.
    dst, src = torch.tensor(edges, dtype=torch.long).t().contiguous()

    self._g = graph((src, dst), num_nodes=features.size(0))
    self._g.ndata["feat"] = features
    self._g.ndata["label"] = labels

    # Process 10 train/val/test node splits.
    train_masks, val_masks, test_masks = [], [], []
    for i in range(10):
        filepath = f"{self.raw_path}/{self.name}_split_0.6_0.2_{i}.npz"
        # ``np.load`` on an .npz archive keeps a file handle open; the
        # context manager closes it once the arrays have been read.
        with np.load(filepath) as split:
            train_masks.append(torch.from_numpy(split["train_mask"]))
            val_masks.append(torch.from_numpy(split["val_mask"]))
            test_masks.append(torch.from_numpy(split["test_mask"]))
    self._g.ndata["train_mask"] = torch.stack(train_masks, dim=1).bool()
    self._g.ndata["val_mask"] = torch.stack(val_masks, dim=1).bool()
    self._g.ndata["test_mask"] = torch.stack(test_masks, dim=1).bool()
| 119 | |
def has_cache(self):
    """Return whether the path stored in ``self.raw_path`` exists."""
    raw_location = self.raw_path
    return os.path.exists(raw_location)