hub / github.com/pyg-team/pytorch_geometric / random_planetoid_splits

Function random_planetoid_splits

benchmark/citation/train_eval.py:19–40
(data, num_classes)

Source from the content-addressed store, hash-verified

def random_planetoid_splits(data, num_classes):
    # Set new random planetoid splits:
    # * 20 * num_classes labels for training
    # * 500 labels for validation
    # * 1000 labels for testing

    indices = []
    for i in range(num_classes):
        index = (data.y == i).nonzero().view(-1)
        index = index[torch.randperm(index.size(0))]
        indices.append(index)

    train_index = torch.cat([i[:20] for i in indices], dim=0)

    rest_index = torch.cat([i[20:] for i in indices], dim=0)
    rest_index = rest_index[torch.randperm(rest_index.size(0))]

    data.train_mask = index_to_mask(train_index, size=data.num_nodes)
    data.val_mask = index_to_mask(rest_index[:500], size=data.num_nodes)
    data.test_mask = index_to_mask(rest_index[500:1500], size=data.num_nodes)

    return data
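The split logic of random_planetoid_splits can be checked without PyTorch. The following is a minimal sketch, not the library's code: it swaps tensors for plain Python lists and `torch.randperm` for `random.shuffle`. The names `random_planetoid_splits_sketch` and `seed` are hypothetical, the local `index_to_mask` is only a stand-in for the helper the original calls, and the toy labels are made up for illustration.

```python
import random


def index_to_mask(index, size):
    # Plain-Python stand-in for the mask helper: True at each listed position.
    mask = [False] * size
    for i in index:
        mask[i] = True
    return mask


def random_planetoid_splits_sketch(labels, num_classes, seed=None):
    # Hypothetical list-based analogue of random_planetoid_splits.
    rng = random.Random(seed)

    # Shuffle the node indices of each class independently.
    per_class = []
    for c in range(num_classes):
        index = [i for i, y in enumerate(labels) if y == c]
        rng.shuffle(index)
        per_class.append(index)

    # The first 20 shuffled nodes of every class form the training set.
    train_index = [i for index in per_class for i in index[:20]]

    # All remaining nodes are pooled, shuffled, then cut into val and test.
    rest_index = [i for index in per_class for i in index[20:]]
    rng.shuffle(rest_index)

    n = len(labels)
    train_mask = index_to_mask(train_index, n)
    val_mask = index_to_mask(rest_index[:500], n)
    test_mask = index_to_mask(rest_index[500:1500], n)
    return train_mask, val_mask, test_mask


# Toy input (made up): 3 classes with 1000 nodes each.
labels = [i % 3 for i in range(3000)]
train_mask, val_mask, test_mask = random_planetoid_splits_sketch(labels, 3, seed=0)
print(sum(train_mask), sum(val_mask), sum(test_mask))  # prints "60 500 1000"
```

Because the splits are taken by slicing, a class with fewer than 20 nodes, or a remainder with fewer than 1500 nodes, does not raise an error; the affected split is simply smaller than its nominal size. Tensor slicing in the original behaves the same way.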

Callers

nothing calls this directly

Calls 4

index_to_mask · Function · 0.90
view · Method · 0.80
append · Method · 0.80
size · Method · 0.45

Tested by

no test coverage detected