| 52 | from cogdl.datasets import GraphDataset |
| 53 | |
class MyGraphDataset(GraphDataset):
    """Toy graph-classification dataset made of randomly generated graphs.

    A minimal example of subclassing ``GraphDataset``: ``process`` builds and
    returns the list of ``Graph`` objects. Replace ``process`` with real
    loading/preprocessing logic for actual data.

    The size of the synthetic data is controlled by the class-level constants
    below. Override them in a subclass to generate differently sized data.
    """

    # Number of graphs produced by ``process``.
    NUM_GRAPHS = 10
    # Edge endpoints are drawn uniformly from [0, NUM_NODES).
    NUM_NODES = 20
    # Number of columns in each graph's (2, NUM_EDGES) ``edge_index``.
    # Endpoints are random, so duplicate edges and self-loops can occur.
    NUM_EDGES = 30
    # Graph labels are drawn uniformly from [0, NUM_LABELS).
    NUM_LABELS = 7

    def __init__(self, path="data.pt"):
        """Create the dataset.

        Args:
            path: File path passed to ``GraphDataset``. Presumably this is
                where the processed data is stored/loaded. TODO confirm
                against the cogdl base class.
        """
        # Assigned before the base constructor runs, in case that constructor
        # (or ``process``, if it is triggered from there) reads ``self.path``.
        self.path = path
        super().__init__(path, metric="accuracy")

    def process(self):
        """Generate ``NUM_GRAPHS`` random graphs.

        Presumably invoked by the ``GraphDataset`` machinery to build the
        data (TODO confirm). If a previously processed file at ``path`` is
        reused instead, changing the constants above may have no effect.

        Returns:
            list of ``Graph``: each has a random ``edge_index`` of shape
            ``(2, NUM_EDGES)`` and a random label ``y`` of shape ``(1,)``.
        """
        # Load and preprocess real data here. For simplicity, this example
        # just generates random graphs.
        return [self._random_graph() for _ in range(self.NUM_GRAPHS)]

    def _random_graph(self):
        """Build one ``Graph`` with random edges and a random label."""
        # Draw the edges before the label so the random-number consumption
        # order is the same as the original per-graph loop.
        edges = torch.randint(0, self.NUM_NODES, (2, self.NUM_EDGES))
        label = torch.randint(0, self.NUM_LABELS, (1,))
        return Graph(edge_index=edges, y=label)
| 68 | |
| 69 | if __name__ == "__main__": |
| 70 | dataset = MyGraphDataset() |