MCPcopy
hub / github.com/FedML-AI/FedML / create_random_split

Function create_random_split

research/SpreadGNN/data/data_loader.py:23–74  ·  view source on GitHub ↗
(path)

Source from the content-addressed store, hash-verified

21 return adj_matrices, feature_matrices, labels
22
def create_random_split(path, train_frac=0.8, val_frac=0.1, seed=None):
    """Load a dataset and randomly split it into train/validation/test sets.

    The default 80/10/10 split is the one suggested in the MoleculeNet
    whitepaper. A shuffled list of sample indices decides the assignment,
    so the adjacency matrix, feature matrix and label of any given sample
    always end up in the same split.

    Args:
        path: Dataset location, passed unchanged to ``get_data``.
        train_frac: Fraction of samples assigned to the training split.
        val_frac: Fraction of samples assigned to the validation split.
            All remaining samples, including any left over by rounding
            down, form the test split.
        seed: Optional seed for a private ``random.Random`` instance, which
            makes the split reproducible. When ``None`` (the default), the
            shuffle uses the module-level ``random`` generator, as before.

    Returns:
        A 9-tuple of lists::

            (train_adj_matrices, train_feature_matrices, train_labels,
             val_adj_matrices, val_feature_matrices, val_labels,
             test_adj_matrices, test_feature_matrices, test_labels)

    Raises:
        ValueError: If ``get_data`` returns sequences of different lengths,
            if a fraction is negative, or if ``train_frac + val_frac`` would
            need more samples than exist.
    """
    adj_matrices, feature_matrices, labels = get_data(path)

    n = len(adj_matrices)
    # Indexing all three sequences with the same shuffled indices only makes
    # sense if they are aligned; otherwise samples would be silently dropped
    # or an IndexError would surface depending on the shuffle.
    if len(feature_matrices) != n or len(labels) != n:
        raise ValueError(
            "get_data returned mismatched lengths: {} adjacency matrices, "
            "{} feature matrices, {} labels".format(
                n, len(feature_matrices), len(labels)
            )
        )
    if train_frac < 0 or val_frac < 0:
        raise ValueError("train_frac and val_frac must be non-negative")

    # Floor each fraction separately, exactly like the original
    # int(0.8 * n) + int(0.1 * n). This is not always equal to int(0.9 * n),
    # and keeping it preserves existing splits for the default arguments.
    train_end = int(train_frac * n)
    val_end = train_end + int(val_frac * n)
    # Check the integer boundary rather than the float sum, so that
    # rounding in train_frac + val_frac cannot cause a spurious error.
    if val_end > n:
        raise ValueError("train_frac + val_frac must not exceed 1")

    order = list(range(n))
    if seed is None:
        # Same call as before, so the split still follows any global
        # random.seed(...) that callers set.
        random.shuffle(order)
    else:
        random.Random(seed).shuffle(order)

    def _take(items, start, stop):
        # Select the samples whose shuffled positions fall in [start, stop).
        return [items[i] for i in order[start:stop]]

    return (
        _take(adj_matrices, 0, train_end),
        _take(feature_matrices, 0, train_end),
        _take(labels, 0, train_end),
        _take(adj_matrices, train_end, val_end),
        _take(feature_matrices, train_end, val_end),
        _take(labels, train_end, val_end),
        _take(adj_matrices, val_end, n),
        _take(feature_matrices, val_end, n),
        _take(labels, val_end, n),
    )
75
76def create_non_uniform_split(args, idxs, client_number, is_train=True):
77 logging.info("create_non_uniform_split------------------------------------------")

Callers 2

get_dataloader — Function · 0.70

Calls 1

get_data — Function · 0.70

Tested by

no test coverage detected