Repository: github.com/pyg-team/pytorch_geometric

Function k_fold

Defined in benchmark/kernel/train_eval.py, lines 98–113.
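Signature: k_fold(dataset, folds)
Returns (train_indices, test_indices, val_indices): three lists of length folds, where entry i of each list holds the torch.long sample indices of fold i's training, test, and validation split.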


```python
import torch
from sklearn.model_selection import StratifiedKFold


def k_fold(dataset, folds):
    # Fixed seed so every run produces the same stratified folds.
    skf = StratifiedKFold(folds, shuffle=True, random_state=12345)

    test_indices, train_indices = [], []
    # StratifiedKFold only uses X for its length, so a zero placeholder
    # suffices; the per-graph labels in y drive the stratification.
    for _, idx in skf.split(torch.zeros(len(dataset)), dataset.data.y):
        test_indices.append(torch.from_numpy(idx).to(torch.long))

    # Fold i validates on the previous fold's test split (i = 0 wraps to the last).
    val_indices = [test_indices[i - 1] for i in range(folds)]

    for i in range(folds):
        # Train on every sample outside fold i's test and validation splits.
        train_mask = torch.ones(len(dataset), dtype=torch.bool)
        train_mask[test_indices[i]] = 0
        train_mask[val_indices[i]] = 0
        train_indices.append(train_mask.nonzero(as_tuple=False).view(-1))

    return train_indices, test_indices, val_indices
```
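To make the fold bookkeeping concrete, here is a dependency-free sketch of only the validation and training assignment step, assuming the test folds are already given. `rotate_folds` is a hypothetical helper, not part of PyG, and the round-robin test folds are a stand-in for StratifiedKFold's output (they are not stratified or shuffled).

```python
from typing import List, Tuple


def rotate_folds(test_folds: List[List[int]], n: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical helper mirroring k_fold: fold i validates on test fold i - 1."""
    k = len(test_folds)
    # Python's negative indexing makes i = 0 pick the last fold, as in k_fold.
    val_folds = [test_folds[i - 1] for i in range(k)]
    train_folds = []
    for i in range(k):
        held_out = set(test_folds[i]) | set(val_folds[i])
        # Keep every remaining index in ascending order
        # (train_mask.nonzero() in k_fold also yields sorted indices).
        train_folds.append([j for j in range(n) if j not in held_out])
    return train_folds, val_folds


# Stand-in test folds: round-robin over 10 samples with k = 5.
n, k = 10, 5
test_folds = [list(range(i, n, k)) for i in range(k)]
train_folds, val_folds = rotate_folds(test_folds, n)
print(test_folds[0], val_folds[0], train_folds[0])  # → [0, 5] [4, 9] [1, 2, 3, 6, 7, 8]
```

Across the k folds, every sample is tested exactly once and validated exactly once. One consequence of this rotation: with folds = 2, each fold's validation split is the other test split, so the training set comes out empty; the scheme only yields training data for folds ≥ 3.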

Callers: 1

Calls: 3

append (method) · 0.80
view (method) · 0.80
to (method) · 0.45

Tested by: no test coverage detected
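As a hedged illustration of what a test could check, the sketch below verifies the two properties k_fold's test splits are meant to have: they partition all sample indices, and each fold's class proportions stay close to the overall ones. `class_shares` and `check_stratified_test_folds` are hypothetical helpers, not part of PyG, and the tolerance is an assumption: exact proportions are not always achievable when class counts do not divide evenly by the number of folds.

```python
from collections import Counter
from typing import Dict, List, Sequence


def class_shares(indices: Sequence[int], labels: Sequence[int]) -> Dict[int, float]:
    """Fraction of each class label among the given sample indices."""
    counts = Counter(labels[j] for j in indices)
    total = len(indices)
    return {c: counts[c] / total for c in counts}


def check_stratified_test_folds(test_folds: List[List[int]], labels: Sequence[int], tol: float) -> None:
    """Hypothetical test helper; raises AssertionError if a property fails."""
    n = len(labels)
    # Every sample must appear in exactly one test fold.
    flat = sorted(j for fold in test_folds for j in fold)
    assert flat == list(range(n)), "test folds must partition range(n)"
    overall = class_shares(range(n), labels)
    for i, fold in enumerate(test_folds):
        shares = class_shares(fold, labels)
        for c, p in overall.items():
            gap = abs(shares.get(c, 0.0) - p)
            assert gap <= tol, f"fold {i}: class {c} share off by {gap:.3f}"


# Hand-built example: six samples of class 0 and three of class 1 in three folds.
labels = [0, 0, 0, 0, 0, 0, 1, 1, 1]
stratified = [[0, 1, 6], [2, 3, 7], [4, 5, 8]]
check_stratified_test_folds(stratified, labels, tol=1e-9)  # passes: each fold is 2/3 class 0
```

Applied to the real function, one would convert tensors first, for example `[t.tolist() for t in test_indices]` for the folds and `dataset.data.y.tolist()` for the labels, and pick a tolerance that allows for per-class counts differing by one between folds.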