# Module-level constant. Not referenced in the visible code; presumably the
# number of random walks generated per node elsewhere in this module — confirm.
N_WALKS = 50
| 18 | |
def load_data(prefix, normalize=True, load_walks=False):
    """Load a graph dataset stored as a family of files sharing ``prefix``.

    Files read:
        ``<prefix>-G.json``          networkx node-link JSON graph (required)
        ``<prefix>-id_map.json``     node id -> integer (required)
        ``<prefix>-class_map.json``  node id -> label(s) (required)
        ``<prefix>-feats.npy``       feature array (optional)
        ``<prefix>-walks.txt``       whitespace-separated tokens per line
                                     (read only when ``load_walks`` is True)

    Args:
        prefix: Path prefix shared by all dataset files.
        normalize: When True and features are present, standardize ``feats``
            with a ``StandardScaler`` fitted only on rows ``id_map[n]`` of
            training nodes (nodes whose ``'val'`` and ``'test'`` attributes
            are both falsy), then transform every row.
        load_walks: When True, parse ``<prefix>-walks.txt`` into ``walks``.

    Returns:
        A tuple ``(G, feats, id_map, walks, class_map)``:
            G: the graph, with nodes lacking a ``'val'`` or ``'test'``
                attribute removed and every edge given a boolean
                ``'train_removed'`` attribute (True when either endpoint has
                a truthy ``'val'`` or ``'test'``).
            feats: the loaded (possibly standardized) feature array, or None
                when ``<prefix>-feats.npy`` does not exist.
            id_map: dict mapping node id -> int (used as a row index into
                ``feats`` during normalization).
            walks: one list of converted tokens per line of the walks file;
                empty when ``load_walks`` is False.
            class_map: dict mapping node id -> label. Values are kept as-is
                when the first value is a list, otherwise converted with int().

        JSON object keys are strings; ``id_map``/``class_map`` keys are
        converted to int when the graph's first node is an int.
        NOTE: nodes removed from ``G`` are NOT removed from ``id_map`` or
        ``class_map``.

    Raises:
        IOError/OSError if a required file is missing; KeyError if a
        training node has no ``id_map`` entry (only when normalizing).
    """
    with open(prefix + "-G.json") as fp:
        G = json_graph.node_link_graph(json.load(fp))

    def identity(n):
        return n

    # Peek at one node to decide how string JSON keys map back to node ids.
    # next(iter(...)) works on networkx 1.x (list) and 2.x+ (NodeView), where
    # G.nodes()[0] would instead be an attribute lookup for node 0.
    # An empty graph falls back to identity conversion instead of crashing.
    first_node = next(iter(G.nodes()), None)
    conversion = int if isinstance(first_node, int) else identity

    if os.path.exists(prefix + "-feats.npy"):
        feats = np.load(prefix + "-feats.npy")
    else:
        print("No features present.. Only identity features will be used.")
        feats = None

    with open(prefix + "-id_map.json") as fp:
        id_map = {conversion(k): int(v) for k, v in json.load(fp).items()}

    with open(prefix + "-class_map.json") as fp:
        raw_class_map = json.load(fp)
    # List-valued labels are kept as lists; anything else goes through int().
    first_label = next(iter(raw_class_map.values()), None)
    lab_conversion = identity if isinstance(first_label, list) else int
    class_map = {conversion(k): lab_conversion(v)
                 for k, v in raw_class_map.items()}

    ## Remove all nodes that do not have val/test annotations
    ## (necessary because of networkx weirdness with the Reddit data)
    # Collect first, then remove: on networkx >= 2, G.nodes() is a live view
    # and removing nodes while iterating it raises RuntimeError.
    broken_nodes = [n for n, attrs in G.nodes(data=True)
                    if 'val' not in attrs or 'test' not in attrs]
    G.remove_nodes_from(broken_nodes)
    print("Removed {:d} nodes that lacked proper annotations due to "
          "networkx versioning issues".format(len(broken_nodes)))

    ## Make sure the graph has edge train_removed annotations
    ## (any existing train_removed value is overwritten)
    print("Loaded data.. now preprocessing..")
    # Node -> attribute dict. Avoids G.node (removed in networkx 2.4) and
    # G.nodes[n] (not subscriptable on networkx 1.x).
    node_attrs = dict(G.nodes(data=True))

    def held_out(n):
        attrs = node_attrs[n]
        return bool(attrs['val'] or attrs['test'])

    for u, v in G.edges():
        G[u][v]['train_removed'] = held_out(u) or held_out(v)

    if normalize and feats is not None:
        from sklearn.preprocessing import StandardScaler
        # Fit only on training-node rows (neither val nor test), then
        # transform all rows.
        train_ids = np.array([id_map[n] for n, attrs in G.nodes(data=True)
                              if not attrs['val'] and not attrs['test']])
        scaler = StandardScaler()
        scaler.fit(feats[train_ids])
        feats = scaler.transform(feats)

    walks = []
    if load_walks:
        with open(prefix + "-walks.txt") as fp:
            # list(...) so each walk is a reusable list on Python 3, where
            # map() returns a one-shot iterator.
            walks = [list(map(conversion, line.split())) for line in fp]

    return G, feats, id_map, walks, class_map
| 76 | |