return:(train_mask,val_mask,test_mask)
(ratio: str, masked_index: Tensor, total_node_num: int, seed: int = 1234567)
| 26 | |
| 27 | |
def get_order(ratio: str, masked_index: Tensor, total_node_num: int, seed: int = 1234567):
    """Randomly split ``masked_index`` into disjoint train/val/test boolean masks.

    Args:
        ratio: Dash-separated integer weights, e.g. ``"8-1-1"``. The first
            weight sizes the train split and the second sizes the val split.
            Every node not assigned to train or val goes to the test split.
        masked_index: Node ids eligible for splitting (a 1-D tensor; a plain
            sequence of ints is also accepted).
        total_node_num: Length of each returned mask.
        seed: Seed for the shuffle that assigns nodes to splits.

    Returns:
        ``(train_mask, val_mask, test_mask)``: three boolean tensors of length
        ``total_node_num``. ``True`` marks the node ids in that split.

    Raises:
        ValueError: if a ratio part is not an integer.
        IndexError: if ``ratio`` has fewer than two parts.
        ZeroDivisionError: if all ratio parts are zero.
    """
    # NOTE: this seeds the *global* ``random`` module (kept from the original
    # behaviour for reproducibility). Later ``random.*`` calls made by the
    # caller are therefore affected too.
    random.seed(seed)

    masked_index = torch.as_tensor(masked_index, dtype=torch.long)
    masked_node_num = len(masked_index)
    shuffle_criterion = list(range(masked_node_num))
    random.shuffle(shuffle_criterion)

    train_val_test_list = [int(i) for i in ratio.split("-")]
    tvt_sum = sum(train_val_test_list)

    # Exact integer floor of (part / sum) * n. The previous float form
    # int(0.29 * 100) evaluates to 28, because 0.29 * 100 == 28.999999999999996.
    train_end_index = train_val_test_list[0] * masked_node_num // tvt_sum
    val_end_index = train_end_index + train_val_test_list[1] * masked_node_num // tvt_sum

    split_positions = (
        shuffle_criterion[:train_end_index],
        shuffle_criterion[train_end_index:val_end_index],
        shuffle_criterion[val_end_index:],
    )

    masks = []
    for positions in split_positions:
        mask = torch.zeros(total_node_num, dtype=torch.bool)
        # Explicit long tensor so that empty splits still index correctly.
        position_tensor = torch.tensor(positions, dtype=torch.long)
        mask[masked_index[position_tensor]] = True
        masks.append(mask)

    return tuple(masks)
| 57 | |
| 58 | |
| 59 | def check_train_containing(train_mask, y): |
no test coverage detected