| 9 | super(MyNodeClassificationDataset, self).__init__(path) |
| 10 | |
def process(self):
    """Build a small synthetic node-classification graph and return it.

    Generates 100 nodes with 30-dimensional standard-normal features, 300
    random edges and binary (0/1) labels. It then splits the nodes
    contiguously into train (first 30%), validation (next 40%) and test
    (last 30%) boolean masks.

    Side effect: the resulting ``Graph`` is also written with ``torch.save``
    to ``"mydata.pt"``, relative to the current working directory.

    Returns:
        Graph: graph carrying ``x``, ``edge_index``, ``y``, ``train_mask``,
        ``val_mask`` and ``test_mask``.
    """
    num_nodes = 100
    num_edges = 300
    feat_dim = 30

    # load or generate your dataset
    # edge_index has shape (2, num_edges). Each column is a pair of node ids
    # drawn independently, so self-loops and duplicate edges are possible.
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    x = torch.randn(num_nodes, feat_dim)
    y = torch.randint(0, 2, (num_nodes,))

    # set train/val/test mask in node_classification task
    # Compute the split boundaries once. The ranges are contiguous and
    # non-overlapping:
    #   [0, train_end) train, [train_end, val_end) val, [val_end, num_nodes) test.
    train_end = int(0.3 * num_nodes)
    val_end = int(0.7 * num_nodes)

    # Allocate bool tensors directly instead of creating floats and casting.
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[:train_end] = True
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask[train_end:val_end] = True
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask[val_end:] = True

    data = Graph(
        x=x,
        edge_index=edge_index,
        y=y,
        train_mask=train_mask,
        val_mask=val_mask,
        test_mask=test_mask,
    )
    # NOTE(review): this filename is hard-coded and does not use the ``path``
    # given to ``__init__``. Confirm whether the base dataset class already
    # persists the returned data to that path; if so, this save is redundant
    # (or writes to the wrong file when ``path`` differs).
    torch.save(data, "mydata.pt")
    return data
| 31 | |
| 32 | |
| 33 | class MyGraphClassificationDataset(GraphDataset): |