| 11 | |
| 12 | |
| 13 | def read_geom_data(folder, dataset_name): |
| 14 | graph_adjacency_list_file_path = osp.join(folder, "out1_graph_edges.txt") |
| 15 | graph_node_features_and_labels_file_path = osp.join(folder, "out1_node_feature_label.txt") |
| 16 | |
| 17 | G = nx.DiGraph() |
| 18 | graph_node_features_dict = {} |
| 19 | graph_labels_dict = {} |
| 20 | |
| 21 | if dataset_name == "film": |
| 22 | with open(graph_node_features_and_labels_file_path) as graph_node_features_and_labels_file: |
| 23 | graph_node_features_and_labels_file.readline() |
| 24 | for line in graph_node_features_and_labels_file: |
| 25 | line = line.rstrip().split("\t") |
| 26 | assert len(line) == 3 |
| 27 | assert int(line[0]) not in graph_node_features_dict and int(line[0]) not in graph_labels_dict |
| 28 | feature_blank = np.zeros(932, dtype=np.uint8) |
| 29 | feature_blank[np.array(line[1].split(","), dtype=np.uint16)] = 1 |
| 30 | graph_node_features_dict[int(line[0])] = feature_blank |
| 31 | graph_labels_dict[int(line[0])] = int(line[2]) |
| 32 | else: |
| 33 | with open(graph_node_features_and_labels_file_path) as graph_node_features_and_labels_file: |
| 34 | graph_node_features_and_labels_file.readline() |
| 35 | for line in graph_node_features_and_labels_file: |
| 36 | line = line.rstrip().split("\t") |
| 37 | assert len(line) == 3 |
| 38 | assert int(line[0]) not in graph_node_features_dict and int(line[0]) not in graph_labels_dict |
| 39 | graph_node_features_dict[int(line[0])] = np.array(line[1].split(","), dtype=np.uint8) |
| 40 | graph_labels_dict[int(line[0])] = int(line[2]) |
| 41 | |
| 42 | with open(graph_adjacency_list_file_path) as graph_adjacency_list_file: |
| 43 | graph_adjacency_list_file.readline() |
| 44 | for line in graph_adjacency_list_file: |
| 45 | line = line.rstrip().split("\t") |
| 46 | assert len(line) == 2 |
| 47 | if int(line[0]) not in G: |
| 48 | G.add_node( |
| 49 | int(line[0]), features=graph_node_features_dict[int(line[0])], label=graph_labels_dict[int(line[0])] |
| 50 | ) |
| 51 | if int(line[1]) not in G: |
| 52 | G.add_node( |
| 53 | int(line[1]), features=graph_node_features_dict[int(line[1])], label=graph_labels_dict[int(line[1])] |
| 54 | ) |
| 55 | G.add_edge(int(line[0]), int(line[1])) |
| 56 | |
| 57 | adj = nx.adjacency_matrix(G, sorted(G.nodes())) |
| 58 | features = np.array([features for _, features in sorted(G.nodes(data="features"), key=lambda x: x[0])]) |
| 59 | labels = np.array([label for _, label in sorted(G.nodes(data="label"), key=lambda x: x[0])]) |
| 60 | |
| 61 | all_masks = [] |
| 62 | for split in range(10): |
| 63 | graph_split_file_path = osp.join(folder, f"{dataset_name}_split_0.6_0.2_{split}.npz") |
| 64 | with np.load(graph_split_file_path) as splits_file: |
| 65 | train_mask = splits_file["train_mask"] |
| 66 | val_mask = splits_file["val_mask"] |
| 67 | test_mask = splits_file["test_mask"] |
| 68 | train_mask = torch.BoolTensor(train_mask) |
| 69 | val_mask = torch.BoolTensor(val_mask) |
| 70 | test_mask = torch.BoolTensor(test_mask) |