MCPcopy
hub / github.com/THUDM/CogDL / process

Method process

tests/datasets/test_customized_data.py:11–30  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

9 super(MyNodeClassificationDataset, self).__init__(path)
10
def process(self):
    """Build a small random graph for the customized node-classification test.

    Draws 100 nodes with 30-dimensional standard-normal features, 300 edges
    whose endpoints are sampled uniformly (self-loops and duplicate edges are
    possible), and 0/1 labels.  Nodes are split by index into contiguous
    train (first 30%), validation (next 40%) and test (last 30%) masks.
    The resulting ``Graph`` is saved to ``"mydata.pt"`` (relative to the
    current working directory) and then returned.

    Returns:
        Graph: graph carrying ``x``, ``edge_index``, ``y`` and the boolean
        ``train_mask`` / ``val_mask`` / ``test_mask`` attributes.
    """
    n_nodes, n_edges, n_feats = 100, 300, 30

    # Keep the draw order (edges, then features, then labels): it decides
    # which values come out of torch's global RNG.
    edge_index = torch.randint(0, n_nodes, (2, n_edges))
    x = torch.randn(n_nodes, n_feats)
    y = torch.randint(0, 2, (n_nodes,))

    # Index-based split boundaries: [0, train_end) train,
    # [train_end, val_end) validation, [val_end, n_nodes) test.
    train_end = int(0.3 * n_nodes)
    val_end = int(0.7 * n_nodes)
    node_ids = torch.arange(n_nodes)  # comparisons below yield bool tensors

    data = Graph(
        x=x,
        edge_index=edge_index,
        y=y,
        train_mask=node_ids < train_end,
        val_mask=(node_ids >= train_end) & (node_ids < val_end),
        test_mask=node_ids >= val_end,
    )
    torch.save(data, "mydata.pt")
    return data
31
32
33class MyGraphClassificationDataset(GraphDataset):

Callers

nothing calls this directly

Calls 1

Graph (class) · 0.90

Tested by

no test coverage detected