MCPcopy
hub / github.com/williamleif/GraphSAGE / load_data

Function load_data

graphsage/utils.py:19–75  ·  view source on GitHub ↗
(prefix, normalize=True, load_walks=False)

Source from the content-addressed store, hash-verified

# NOTE(review): presumably the number of random walks generated per node by
# walk-generation code elsewhere in this module -- confirm against its usage.
N_WALKS = 50

def load_data(prefix, normalize=True, load_walks=False):
    """Load a GraphSAGE-formatted dataset whose files share a common prefix.

    Files read (``prefix`` is prepended to each suffix):

    * ``-G.json``         node-link JSON graph (required)
    * ``-feats.npy``      node feature matrix (optional)
    * ``-id_map.json``    node id -> feature row index (required)
    * ``-class_map.json`` node id -> label, either a list or an int (required)
    * ``-walks.txt``      whitespace-separated node ids per line
                          (read only when ``load_walks`` is true)

    Nodes lacking a ``'val'`` or ``'test'`` attribute are removed from the
    graph. Every remaining edge gets a boolean ``'train_removed'`` attribute,
    which is True when either endpoint is a val or test node.

    Args:
        prefix: Path prefix shared by all dataset files.
        normalize: If True and features exist, standardize the features with
            a ``StandardScaler`` fitted on training nodes only (nodes that are
            neither val nor test).
        load_walks: If True, read ``prefix + "-walks.txt"``.

    Returns:
        Tuple ``(G, feats, id_map, walks, class_map)``. ``feats`` is None when
        no feature file exists. ``walks`` is a list of lists of node ids; it
        is empty unless ``load_walks`` is true.
    """
    def _read_json(path):
        # Context manager so the file handle is closed promptly.
        with open(path) as fp:
            return json.load(fp)

    G = json_graph.node_link_graph(_read_json(prefix + "-G.json"))

    # networkx 1.x exposes per-node attribute dicts as ``G.node``. That alias
    # was removed in networkx 2.4, where ``G.nodes[n]`` is the equivalent.
    node_attrs = G.node if hasattr(G, "node") else G.nodes

    # JSON object keys are always strings. If the graph's node ids are ints,
    # the keys of id_map/class_map must be converted to match.
    # ``G.nodes()[0]`` is only "the first node" on networkx 1.x (where it
    # returns a list), so take the first node in a version-independent way.
    first_node = next(iter(G.nodes()), None)
    conversion = int if isinstance(first_node, int) else (lambda n: n)

    feats_path = prefix + "-feats.npy"
    if os.path.exists(feats_path):
        feats = np.load(feats_path)
    else:
        print("No features present.. Only identity features will be used.")
        feats = None

    id_map = {conversion(k): int(v)
              for k, v in _read_json(prefix + "-id_map.json").items()}

    raw_class_map = _read_json(prefix + "-class_map.json")
    # Labels stored as lists are kept as-is; scalar labels are cast to int.
    # Use a default so an empty class map does not raise IndexError.
    first_label = next(iter(raw_class_map.values()), None)
    lab_conversion = (lambda n: n) if isinstance(first_label, list) else int
    class_map = {conversion(k): lab_conversion(v)
                 for k, v in raw_class_map.items()}

    ## Remove all nodes that do not have val/test annotations
    ## (necessary because of networkx weirdness with the Reddit data)
    # Collect first, then remove. Removing nodes while iterating a live
    # networkx >= 2 node view raises "dictionary changed size during iteration".
    broken_nodes = [node for node in G.nodes()
                    if 'val' not in node_attrs[node]
                    or 'test' not in node_attrs[node]]
    G.remove_nodes_from(broken_nodes)
    broken_count = len(broken_nodes)
    print("Removed {:d} nodes that lacked proper annotations due to networkx versioning issues".format(broken_count))

    ## Make sure the graph has edge train_removed annotations
    ## (some datasets might already have this..)
    print("Loaded data.. now preprocessing..")
    for u, v in G.edges():
        G[u][v]['train_removed'] = bool(
            node_attrs[u]['val'] or node_attrs[v]['val'] or
            node_attrs[u]['test'] or node_attrs[v]['test'])

    if normalize and feats is not None:
        from sklearn.preprocessing import StandardScaler
        # Fit only on training rows (neither val nor test), then apply the
        # same transform to every row. Val/test feature statistics therefore
        # do not influence the scaling.
        train_ids = np.array([id_map[n] for n in G.nodes()
                              if not node_attrs[n]['val']
                              and not node_attrs[n]['test']])
        train_feats = feats[train_ids]
        scaler = StandardScaler()
        scaler.fit(train_feats)
        feats = scaler.transform(feats)

    walks = []
    if load_walks:
        with open(prefix + "-walks.txt") as fp:
            for line in fp:
                # Build a real list. On Python 3, ``map`` returns a one-shot
                # iterator that cannot be indexed or re-iterated.
                walks.append([conversion(n) for n in line.split()])

    return G, feats, id_map, walks, class_map


Callers 2

main (Function) · 0.90
main (Function) · 0.90

Calls 1

load (Method) · 0.80

Tested by

no test coverage detected