MCP · Copy · Index your code
hub / github.com/williamleif/GraphSAGE / train

Function train

graphsage/unsupervised_train.py:132–372  ·  view source on GitHub ↗
(train_data, test_data=None)

Source from the content-addressed store, hash-verified

130 return placeholders
131
132def train(train_data, test_data=None):
133 G = train_data[0]
134 features = train_data[1]
135 id_map = train_data[2]
136
137 if not features is None:
138 # pad with dummy zero vector
139 features = np.vstack([features, np.zeros((features.shape[1],))])
140
141 context_pairs = train_data[3] if FLAGS.random_context else None
142 placeholders = construct_placeholders()
143 minibatch = EdgeMinibatchIterator(G,
144 id_map,
145 placeholders, batch_size=FLAGS.batch_size,
146 max_degree=FLAGS.max_degree,
147 num_neg_samples=FLAGS.neg_sample_size,
148 context_pairs = context_pairs)
149 adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
150 adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
151
152 if FLAGS.model == 'graphsage_mean':
153 # Create model
154 sampler = UniformNeighborSampler(adj_info)
155 layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
156 SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
157
158 model = SampleAndAggregate(placeholders,
159 features,
160 adj_info,
161 minibatch.deg,
162 layer_infos=layer_infos,
163 model_size=FLAGS.model_size,
164 identity_dim = FLAGS.identity_dim,
165 logging=True)
166 elif FLAGS.model == 'gcn':
167 # Create model
168 sampler = UniformNeighborSampler(adj_info)
169 layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, 2*FLAGS.dim_1),
170 SAGEInfo("node", sampler, FLAGS.samples_2, 2*FLAGS.dim_2)]
171
172 model = SampleAndAggregate(placeholders,
173 features,
174 adj_info,
175 minibatch.deg,
176 layer_infos=layer_infos,
177 aggregator_type="gcn",
178 model_size=FLAGS.model_size,
179 identity_dim = FLAGS.identity_dim,
180 concat=False,
181 logging=True)
182
183 elif FLAGS.model == 'graphsage_seq':
184 sampler = UniformNeighborSampler(adj_info)
185 layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
186 SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
187
188 model = SampleAndAggregate(placeholders,
189 features,

Callers 1

main (function) · 0.70

Calls 12

shuffle (method) · 0.95
end (method) · 0.95
SampleAndAggregate (class) · 0.90
Node2VecModel (class) · 0.90
run_random_walks (function) · 0.90
save_val_embeddings (function) · 0.85
construct_placeholders (function) · 0.70
log_dir (function) · 0.70
evaluate (function) · 0.70

Tested by

no test coverage detected