| 12 | from cogdl.datasets import NodeDataset, generate_random_graph |
| 13 | |
class MyNodeDataset(NodeDataset):
    """Example node-classification dataset built from synthetic random data.

    Replace the body of ``process`` with code that loads your own data and
    returns a ``Graph``.

    Args:
        path: File path forwarded to ``NodeDataset`` (presumably where the
            processed graph is stored -- confirm against cogdl's docs).
    """

    def __init__(self, path="data.pt"):
        # Set before calling the base constructor in case it needs the path
        # (NOTE(review): assumes NodeDataset.__init__ may call process()).
        self.path = path
        # Features are left unscaled; the dataset's metric is accuracy.
        super().__init__(path, scale_feat=False, metric="accuracy")

    def process(self):
        """Load (here: randomly generate) the dataset and return it as a ``Graph``.

        Returns:
            Graph with node features ``x`` of shape (num_nodes, feat_dim),
            random edges ``edge_index`` of shape (2, num_edges), binary labels
            ``y`` of shape (num_nodes,), and boolean ``train_mask`` /
            ``val_mask`` / ``test_mask`` splitting the nodes 30% / 40% / 30%
            in index order.
        """
        num_nodes, num_edges, feat_dim = 100, 300, 30

        # Load or generate your dataset. The random edges may include
        # self-loops and duplicates, which is acceptable for this example.
        edge_index = torch.randint(0, num_nodes, (2, num_edges))
        x = torch.randn(num_nodes, feat_dim)
        y = torch.randint(0, 2, (num_nodes,))

        # Compute the split boundaries once so the three masks are always
        # disjoint and together cover every node.
        train_end = int(0.3 * num_nodes)
        val_end = int(0.7 * num_nodes)

        # Contiguous split in index order. With real data whose node order
        # is not random, consider permuting node indices before splitting.
        train_mask = torch.zeros(num_nodes, dtype=torch.bool)
        train_mask[:train_end] = True
        val_mask = torch.zeros(num_nodes, dtype=torch.bool)
        val_mask[train_end:val_end] = True
        test_mask = torch.zeros(num_nodes, dtype=torch.bool)
        test_mask[val_end:] = True

        return Graph(
            x=x,
            edge_index=edge_index,
            y=y,
            train_mask=train_mask,
            val_mask=val_mask,
            test_mask=test_mask,
        )
| 37 | |
| 38 | if __name__ == "__main__": |
| 39 | # Train customized dataset via defining a new class |