MCP · copy
hub / github.com/THUDM/CogDL / MyNodeDataset

Class MyNodeDataset

docs/source/tutorial_cn/examples/3custom_dataset_cn.py:14–36  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

12from cogdl.datasets import NodeDataset, generate_random_graph
13
class MyNodeDataset(NodeDataset):
    """Example custom node-classification dataset filled with random data.

    Shows the minimal recipe for a CogDL custom dataset: subclass
    ``NodeDataset`` and implement ``process`` so that it returns a ``Graph``.

    Args:
        path: File path stored on the instance and forwarded to
            ``NodeDataset.__init__`` (defaults to ``"data.pt"``).
            NOTE(review): presumably where the processed graph is cached
            by the base class — confirm against NodeDataset.
    """

    def __init__(self, path="data.pt"):
        self.path = path
        # Features are left unscaled; evaluation uses the "accuracy" metric.
        super().__init__(path, scale_feat=False, metric="accuracy")

    def process(self):
        """Build the dataset and return it as a ``Graph``.

        Replace the random placeholder tensors below with code that loads
        your real nodes, edges, features and labels.
        """
        n_nodes = 100
        n_edges = 300
        n_feats = 30

        # Placeholder data. Keep this creation order: each call consumes
        # draws from torch's global RNG.
        edges = torch.randint(0, n_nodes, (2, n_edges))
        features = torch.randn(n_nodes, n_feats)
        labels = torch.randint(0, 2, (n_nodes,))  # binary labels in {0, 1}

        # Split nodes by index: first 30% train, next 40% val, last 30% test.
        train_end = int(0.3 * n_nodes)
        val_end = int(0.7 * n_nodes)

        train_mask = torch.zeros(n_nodes, dtype=torch.bool)
        val_mask = torch.zeros(n_nodes, dtype=torch.bool)
        test_mask = torch.zeros(n_nodes, dtype=torch.bool)
        train_mask[:train_end] = True
        val_mask[train_end:val_end] = True
        test_mask[val_end:] = True

        return Graph(
            x=features,
            edge_index=edges,
            y=labels,
            train_mask=train_mask,
            val_mask=val_mask,
            test_mask=test_mask,
        )
37
38if __name__ == "__main__":
39 # Train customized dataset via defining a new class

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected