Repository: github.com/tkipf/gcn

Function load_data

gcn/utils.py:24–90

Loads a dataset's input files (ind.dataset_str.x, .y, .tx, .ty, .allx, .ally, .graph and ind.dataset_str.test.index) from the gcn/data directory, then assembles the sparse feature matrix, the adjacency matrix, the label matrix, and the train/validation/test masks. Each input file is described in the docstring.
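To make the seven pickled inputs concrete, here is a minimal round-trip sketch that writes a toy dataset under the same ind.<dataset_str>.<name> naming scheme and reads it back in the same order as the loading loop in load_data. It is an illustration under stated assumptions, not the repository's tooling: write_toy_dataset, read_objects and the dataset name 'toy' are hypothetical, plain Python lists stand in for the scipy.sparse matrices and numpy arrays the real files hold, and the test index file (read separately through parse_index_file) is left out.

```python
import os
import pickle
import tempfile
from collections import defaultdict

# Same order as the `names` list in load_data.
NAMES = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']


def write_toy_dataset(data_dir, dataset_str):
    """Hypothetical helper: pickle a tiny dataset as ind.<dataset_str>.<name> files.

    Plain lists stand in for the scipy.sparse matrices and numpy arrays
    stored in the real files.
    """
    objects = {
        'x': [[1, 0]],             # features of the one labeled training node
        'y': [[1, 0]],             # one-hot label for x
        'tx': [[0, 1]],            # features of the single test node
        'ty': [[0, 1]],            # one-hot label for tx
        'allx': [[1, 0], [1, 1]],  # labeled + unlabeled training features
        'ally': [[1, 0], [0, 1]],  # labels for the allx rows
        'graph': defaultdict(list, {0: [1], 1: [0, 2], 2: [1]}),  # {node: [neighbors]}
    }
    for name in NAMES:
        path = os.path.join(data_dir, "ind.{}.{}".format(dataset_str, name))
        with open(path, 'wb') as f:
            pickle.dump(objects[name], f)


def read_objects(data_dir, dataset_str):
    """Hypothetical helper: unpickle the seven objects in load_data's order."""
    objects = []
    for name in NAMES:
        path = os.path.join(data_dir, "ind.{}.{}".format(dataset_str, name))
        with open(path, 'rb') as f:
            # encoding='latin1' only matters for pickles written under Python 2
            # (as load_data expects); it is harmless for these Python 3 pickles.
            objects.append(pickle.load(f, encoding='latin1'))
    return tuple(objects)


with tempfile.TemporaryDirectory() as data_dir:
    write_toy_dataset(data_dir, 'toy')
    x, y, tx, ty, allx, ally, graph = read_objects(data_dir, 'toy')
```

The graph survives the round trip as a collections.defaultdict, matching the docstring's description of ind.dataset_str.graph.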

(dataset_str)

```python
# Module-level imports from the top of gcn/utils.py that this function uses.
# parse_index_file and sample_mask are helper functions from the same module.
import pickle as pkl
import sys

import networkx as nx
import numpy as np
import scipy.sparse as sp


def load_data(dataset_str):
    """
    Loads input data from gcn/data directory

    ind.dataset_str.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
    ind.dataset_str.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;
    ind.dataset_str.allx => the feature vectors of both labeled and unlabeled training instances
        (a superset of ind.dataset_str.x) as scipy.sparse.csr.csr_matrix object;
    ind.dataset_str.y => the one-hot labels of the labeled training instances as numpy.ndarray object;
    ind.dataset_str.ty => the one-hot labels of the test instances as numpy.ndarray object;
    ind.dataset_str.ally => the labels for instances in ind.dataset_str.allx as numpy.ndarray object;
    ind.dataset_str.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict
        object;
    ind.dataset_str.test.index => the indices of test instances in graph, for the inductive setting as list object.

    All objects above must be saved using python pickle module.

    :param dataset_str: Dataset name
    :return: All data input files loaded (as well as the training/test data).
    """
    # Unpickle the seven objects; latin1 decoding lets Python 3 read pickles written under Python 2.
    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for i in range(len(names)):
        with open("data/ind.{}.{}".format(dataset_str, names[i]), 'rb') as f:
            if sys.version_info > (3, 0):
                objects.append(pkl.load(f, encoding='latin1'))
            else:
                objects.append(pkl.load(f))

    x, y, tx, ty, allx, ally, graph = tuple(objects)
    test_idx_reorder = parse_index_file("data/ind.{}.test.index".format(dataset_str))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset_str == 'citeseer':
        # Fix citeseer dataset (there are some isolated nodes in the graph)
        # Find isolated nodes, add them as zero-vecs into the right position
        test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
        tx_extended[test_idx_range-min(test_idx_range), :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
        ty_extended[test_idx_range-min(test_idx_range), :] = ty
        ty = ty_extended

    # Stack training and test rows, then move each test row from its sorted
    # slot to the node index given by the test index file.
    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    # Splits: the first len(y) nodes train, the next 500 validate,
    # and the test nodes come from the index file.
    idx_test = test_idx_range.tolist()
    idx_train = range(len(y))
    idx_val = range(len(y), len(y)+500)

    train_mask = sample_mask(idx_train, labels.shape[0])
    val_mask = sample_mask(idx_val, labels.shape[0])
    test_mask = sample_mask(idx_test, labels.shape[0])
    # ... lines 82-90 of gcn/utils.py (the rest of the function) are not shown here.
```
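The least obvious part of load_data is the test-row bookkeeping: tx holds the test rows in the order of the test index file, the citeseer branch pads missing indices with zero rows, and the fancy-indexed assignment then moves every test row to its node index. The sketch below replays that logic on lists of lists so it can be traced by hand. It is a hedged illustration, not the repository's code: assemble_rows is a hypothetical helper, tx[k] is assumed to belong to node test_idx_reorder[k], and, as load_data itself relies on, the test indices are assumed to start right after the allx rows.

```python
def assemble_rows(allx, tx, test_idx_reorder, width):
    """Hypothetical helper: stack training and test rows, then move each
    test row to its node index.

    Lists of lists stand in for the scipy.sparse matrices in load_data.
    Assumes tx[k] belongs to node test_idx_reorder[k] and that the smallest
    test index equals len(allx).
    """
    test_idx_range = sorted(test_idx_reorder)
    lo, hi = test_idx_range[0], test_idx_range[-1]

    # Citeseer-style padding: one zero row per index in [lo, hi]. Indices that
    # never appear in the test file (isolated nodes) keep their zero row.
    tx_extended = [[0] * width for _ in range(hi - lo + 1)]
    for k, idx in enumerate(test_idx_range):
        tx_extended[idx - lo] = list(tx[k])

    # Equivalent of sp.vstack((allx, tx)).
    rows = [list(r) for r in allx] + tx_extended

    # Equivalent of features[test_idx_reorder, :] = features[test_idx_range, :].
    # All source rows are read first, as fancy indexing evaluates the
    # right-hand side before assigning, so no row is overwritten early.
    moved = [rows[i] for i in test_idx_range]
    for dst, row in zip(test_idx_reorder, moved):
        rows[dst] = row
    return rows


# Toy data: nodes 0-1 are training nodes; nodes 2, 3 and 5 are test nodes
# listed in file order 5, 2, 3; node 4 is absent from the test file.
features = assemble_rows(
    allx=[[1, 0], [1, 1]],
    tx=[[5, 5], [2, 2], [3, 3]],
    test_idx_reorder=[5, 2, 3],
    width=2,
)
```

Here node 4 ends up as a zero row and each test row lands at its own node index. When the test indices are contiguous the padding step changes nothing, which is presumably why load_data applies it only to citeseer; calling the same helper with ally and ty mirrors the label reordering.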

Callers 1

train.py · File · 0.85

Calls 3

parse_index_file · Function · 0.85
sample_mask · Function · 0.85
load (pkl.load) · Method · 0.80
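The bodies of parse_index_file and sample_mask do not appear on this page. As a hedged illustration only, the stand-ins below assume parse_index_file reads one integer node index per line of a text file and sample_mask returns a list of booleans that is True at the given positions; both are hypothetical reimplementations in plain Python rather than the module's own code, and the toy sizes (10 nodes, 3 labeled, a validation block of 2 instead of 500) are made up. They show how load_data carves out its three splits.

```python
import os
import tempfile


def parse_index_file(filename):
    """Hypothetical stand-in: read one integer node index per non-empty line."""
    with open(filename) as f:
        return [int(line) for line in f if line.strip()]


def sample_mask(idx, length):
    """Hypothetical stand-in: `length` booleans, True at each index in idx."""
    mask = [False] * length
    for i in idx:
        mask[i] = True
    return mask


num_nodes = 10   # plays the role of labels.shape[0]
num_labeled = 3  # plays the role of len(y)
val_size = 2     # load_data uses 500; shrunk for this toy example

with tempfile.TemporaryDirectory() as data_dir:
    path = os.path.join(data_dir, "ind.toy.test.index")
    with open(path, "w") as f:
        f.write("9\n7\n8\n")
    test_idx_reorder = parse_index_file(path)

idx_test = sorted(test_idx_reorder)  # load_data masks the sorted test indices
train_mask = sample_mask(range(num_labeled), num_nodes)
val_mask = sample_mask(range(num_labeled, num_labeled + val_size), num_nodes)
test_mask = sample_mask(idx_test, num_nodes)
```

Because the validation range starts at len(y), the train and validation masks never overlap.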

Tested by

no test coverage detected