MCP · copy
hub / github.com/THUDM/CogDL / read_geom_data

Function read_geom_data

cogdl/datasets/geom_data.py:13–82  ·  view source on GitHub ↗
(folder, dataset_name)

Source from the content-addressed store, hash-verified

11
12
13def read_geom_data(folder, dataset_name):
14 graph_adjacency_list_file_path = osp.join(folder, "out1_graph_edges.txt")
15 graph_node_features_and_labels_file_path = osp.join(folder, "out1_node_feature_label.txt")
16
17 G = nx.DiGraph()
18 graph_node_features_dict = {}
19 graph_labels_dict = {}
20
21 if dataset_name == "film":
22 with open(graph_node_features_and_labels_file_path) as graph_node_features_and_labels_file:
23 graph_node_features_and_labels_file.readline()
24 for line in graph_node_features_and_labels_file:
25 line = line.rstrip().split("\t")
26 assert len(line) == 3
27 assert int(line[0]) not in graph_node_features_dict and int(line[0]) not in graph_labels_dict
28 feature_blank = np.zeros(932, dtype=np.uint8)
29 feature_blank[np.array(line[1].split(","), dtype=np.uint16)] = 1
30 graph_node_features_dict[int(line[0])] = feature_blank
31 graph_labels_dict[int(line[0])] = int(line[2])
32 else:
33 with open(graph_node_features_and_labels_file_path) as graph_node_features_and_labels_file:
34 graph_node_features_and_labels_file.readline()
35 for line in graph_node_features_and_labels_file:
36 line = line.rstrip().split("\t")
37 assert len(line) == 3
38 assert int(line[0]) not in graph_node_features_dict and int(line[0]) not in graph_labels_dict
39 graph_node_features_dict[int(line[0])] = np.array(line[1].split(","), dtype=np.uint8)
40 graph_labels_dict[int(line[0])] = int(line[2])
41
42 with open(graph_adjacency_list_file_path) as graph_adjacency_list_file:
43 graph_adjacency_list_file.readline()
44 for line in graph_adjacency_list_file:
45 line = line.rstrip().split("\t")
46 assert len(line) == 2
47 if int(line[0]) not in G:
48 G.add_node(
49 int(line[0]), features=graph_node_features_dict[int(line[0])], label=graph_labels_dict[int(line[0])]
50 )
51 if int(line[1]) not in G:
52 G.add_node(
53 int(line[1]), features=graph_node_features_dict[int(line[1])], label=graph_labels_dict[int(line[1])]
54 )
55 G.add_edge(int(line[0]), int(line[1]))
56
57 adj = nx.adjacency_matrix(G, sorted(G.nodes()))
58 features = np.array([features for _, features in sorted(G.nodes(data="features"), key=lambda x: x[0])])
59 labels = np.array([label for _, label in sorted(G.nodes(data="label"), key=lambda x: x[0])])
60
61 all_masks = []
62 for split in range(10):
63 graph_split_file_path = osp.join(folder, f"{dataset_name}_split_0.6_0.2_{split}.npz")
64 with np.load(graph_split_file_path) as splits_file:
65 train_mask = splits_file["train_mask"]
66 val_mask = splits_file["val_mask"]
67 test_mask = splits_file["test_mask"]
68 train_mask = torch.BoolTensor(train_mask)
69 val_mask = torch.BoolTensor(val_mask)
70 test_mask = torch.BoolTensor(test_mask)

Callers 1

process · Method · 0.85

Calls 2

Graph · Class · 0.90
nodes · Method · 0.80

Tested by

no test coverage detected