(dataset, folds)
| 96 | |
| 97 | |
def k_fold(dataset, folds, *, random_state=12345):
    """Split ``dataset`` into ``folds`` stratified train/test/validation folds.

    Test folds come from a shuffled ``StratifiedKFold`` stratified on
    ``dataset.data.y``. For split ``i`` the validation indices are the test
    indices of fold ``i - 1`` (split 0 wraps around to the last fold), and
    the training indices are every sample in neither of those two folds.

    Args:
        dataset: Sized dataset whose per-sample labels live at
            ``dataset.data.y`` (presumably a PyG in-memory dataset with one
            label per graph -- TODO confirm against callers).
        folds: Number of folds, passed as ``n_splits`` to ``StratifiedKFold``.
        random_state: Seed for the shuffling ``StratifiedKFold``. The default
            reproduces the split this function has always produced.

    Returns:
        A tuple ``(train_indices, test_indices, val_indices)`` of three lists,
        each of length ``folds``, holding 1-D ``torch.long`` index tensors.
    """
    skf = StratifiedKFold(folds, shuffle=True, random_state=random_state)

    # Only len(X) matters to StratifiedKFold, so a zero placeholder suffices.
    test_indices = [
        torch.from_numpy(idx).to(torch.long)
        for _, idx in skf.split(torch.zeros(len(dataset)), dataset.data.y)
    ]

    # Validation split i reuses the test fold of i - 1; Python's negative
    # indexing makes split 0 take the last fold, so (for folds >= 2) the
    # validation and test folds of a split are always different folds.
    val_indices = [test_indices[i - 1] for i in range(folds)]

    train_indices = []
    for test_idx, val_idx in zip(test_indices, val_indices):
        train_mask = torch.ones(len(dataset), dtype=torch.bool)
        train_mask[test_idx] = False
        train_mask[val_idx] = False
        train_indices.append(train_mask.nonzero(as_tuple=False).view(-1))

    return train_indices, test_indices, val_indices
| 114 | |
| 115 | |
| 116 | def num_graphs(data): |
no test coverage detected