MCPcopy
hub / github.com/THUDM/CogDL / get_order

Function get_order

cogdl/datasets/rd2cd_data.py:28–56  ·  view source on GitHub ↗

Returns: (train_mask, val_mask, test_mask)

(ratio: str, masked_index: Tensor, total_node_num: int, seed: int = 1234567)

Source from the content-addressed store, hash-verified

26
27
def get_order(ratio: str, masked_index: Tensor, total_node_num: int, seed: int = 1234567):
    """Randomly split ``masked_index`` into train/val/test boolean masks.

    Args:
        ratio: Relative split sizes as a dash-separated string of non-negative
            integers, e.g. ``"60-20-20"``. The fields are normalised by their
            sum, so ``"3-1-1"`` and ``"60-20-20"`` are equivalent. The first
            field sizes the training split and the second the validation split.
            Every remaining shuffled position (the third field's share plus any
            flooring leftovers) goes to the test split.
        masked_index: Node ids eligible for the split. It is indexed with a
            list of positions (presumably a 1-D integer tensor — TODO confirm
            against callers).
        total_node_num: Length of each returned mask.
        seed: Seed for the shuffle.

    Returns:
        ``(train_mask, val_mask, test_mask)``: three ``torch.bool`` tensors of
        shape ``(total_node_num,)``. They are disjoint provided ``masked_index``
        contains no duplicate ids.

    Raises:
        ValueError: if ``ratio`` has a non-integer field, has fewer than two
            fields, or its fields sum to zero.
    """
    # NOTE: this reseeds the *global* ``random`` module. Kept as-is because
    # callers may depend on the post-call global random state.
    random.seed(seed)

    masked_node_num = len(masked_index)
    shuffle_criterion = list(range(masked_node_num))
    random.shuffle(shuffle_criterion)

    train_val_test_list = [int(part) for part in ratio.split("-")]
    if len(train_val_test_list) < 2:
        raise ValueError(f"ratio must look like 'train-val-test', got {ratio!r}")
    tvt_sum = sum(train_val_test_list)
    if tvt_sum <= 0:
        raise ValueError(f"ratio fields must not all be zero, got {ratio!r}")

    # Exact integer floor(part * n / total). The previous int((part / total) * n)
    # could land just below an integer (0.29 * 100 == 28.999999999999996) and
    # silently move a node from the train/val split into the test split.
    train_end_index = train_val_test_list[0] * masked_node_num // tvt_sum
    val_end_index = train_end_index + train_val_test_list[1] * masked_node_num // tvt_sum

    # Slice the shuffled positions; the test split takes everything left over.
    train_mask_index = shuffle_criterion[:train_end_index]
    val_mask_index = shuffle_criterion[train_end_index:val_end_index]
    test_mask_index = shuffle_criterion[val_end_index:]

    # Map the shuffled positions back to node ids and mark them in full-size masks.
    train_mask = torch.zeros(total_node_num, dtype=torch.bool)
    train_mask[masked_index[train_mask_index]] = True
    val_mask = torch.zeros(total_node_num, dtype=torch.bool)
    val_mask[masked_index[val_mask_index]] = True
    test_mask = torch.zeros(total_node_num, dtype=torch.bool)
    test_mask[masked_index[test_mask_index]] = True

    return (train_mask, val_mask, test_mask)
57
58
59def check_train_containing(train_mask, y):

Callers 1

get_whole_mask · Function · 0.85

Calls 1

shuffle · Method · 0.45

Tested by

no test coverage detected