(path)
| 21 | return adj_matrices, feature_matrices, labels |
| 22 | |
def create_random_split(path, train_frac=0.8, val_frac=0.1, seed=None):
    """Load a dataset via ``get_data`` and split it randomly into train/val/test.

    The defaults give the random 80/10/10 split suggested in the MoleculeNet
    whitepaper. The train part holds ``int(train_frac * n)`` samples and the
    validation part holds ``int(val_frac * n)`` samples. The test part receives
    every remaining sample, so truncation never drops data.

    Args:
        path: Passed unchanged to ``get_data``. ``get_data`` is expected to
            return three parallel, indexable sequences
            ``(adj_matrices, feature_matrices, labels)``. Position ``i`` in
            each sequence must describe the same sample.
        train_frac: Fraction of samples assigned to the training part.
        val_frac: Fraction of samples assigned to the validation part.
        seed: Optional seed for a private ``random.Random`` instance, which
            makes the split reproducible. When ``None`` (the default), the
            module-level ``random`` generator is used instead.

    Returns:
        A 9-tuple of lists::

            (train_adj_matrices, train_feature_matrices, train_labels,
             val_adj_matrices, val_feature_matrices, val_labels,
             test_adj_matrices, test_feature_matrices, test_labels)

    Raises:
        ValueError: If a fraction is negative, if the fractions sum to more
            than 1, or if the three sequences returned by ``get_data`` have
            different lengths.
    """
    # Validate cheap arguments before doing the (possibly expensive) load.
    if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1:
        raise ValueError(
            "train_frac and val_frac must be non-negative and sum to at most 1, "
            f"got train_frac={train_frac}, val_frac={val_frac}"
        )

    adj_matrices, feature_matrices, labels = get_data(path)

    n = len(adj_matrices)
    # All three sequences are indexed with the same shuffled positions. A length
    # mismatch would silently drop samples or fail with an obscure IndexError.
    if len(feature_matrices) != n or len(labels) != n:
        raise ValueError(
            "get_data returned sequences of different lengths: "
            f"{n} adjacency matrices, {len(feature_matrices)} feature matrices, "
            f"{len(labels)} labels"
        )

    # Boundaries use the same arithmetic as a fixed 80/10/10 split when the
    # defaults are used: train = [0, int(0.8n)), val = [int(0.8n), +int(0.1n)),
    # test = the rest.
    train_end = int(train_frac * n)
    val_end = train_end + int(val_frac * n)

    all_idxs = list(range(n))
    if seed is None:
        random.shuffle(all_idxs)
    else:
        random.Random(seed).shuffle(all_idxs)

    train_idxs = all_idxs[:train_end]
    val_idxs = all_idxs[train_end:val_end]
    test_idxs = all_idxs[val_end:]

    def take(items, idxs):
        # Gather the samples of one sequence at the given shuffled positions.
        return [items[i] for i in idxs]

    return (
        take(adj_matrices, train_idxs),
        take(feature_matrices, train_idxs),
        take(labels, train_idxs),
        take(adj_matrices, val_idxs),
        take(feature_matrices, val_idxs),
        take(labels, val_idxs),
        take(adj_matrices, test_idxs),
        take(feature_matrices, test_idxs),
        take(labels, test_idxs),
    )
| 75 | |
| 76 | def create_non_uniform_split(args, idxs, client_number, is_train=True): |
| 77 | logging.info("create_non_uniform_split------------------------------------------") |
no test coverage detected