(train_data, test_data=None)
| 130 | return placeholders |
| 131 | |
| 132 | def train(train_data, test_data=None): |
| 133 | G = train_data[0] |
| 134 | features = train_data[1] |
| 135 | id_map = train_data[2] |
| 136 | |
| 137 | if not features is None: |
| 138 | # pad with dummy zero vector |
| 139 | features = np.vstack([features, np.zeros((features.shape[1],))]) |
| 140 | |
| 141 | context_pairs = train_data[3] if FLAGS.random_context else None |
| 142 | placeholders = construct_placeholders() |
| 143 | minibatch = EdgeMinibatchIterator(G, |
| 144 | id_map, |
| 145 | placeholders, batch_size=FLAGS.batch_size, |
| 146 | max_degree=FLAGS.max_degree, |
| 147 | num_neg_samples=FLAGS.neg_sample_size, |
| 148 | context_pairs = context_pairs) |
| 149 | adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape) |
| 150 | adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info") |
| 151 | |
| 152 | if FLAGS.model == 'graphsage_mean': |
| 153 | # Create model |
| 154 | sampler = UniformNeighborSampler(adj_info) |
| 155 | layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), |
| 156 | SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)] |
| 157 | |
| 158 | model = SampleAndAggregate(placeholders, |
| 159 | features, |
| 160 | adj_info, |
| 161 | minibatch.deg, |
| 162 | layer_infos=layer_infos, |
| 163 | model_size=FLAGS.model_size, |
| 164 | identity_dim = FLAGS.identity_dim, |
| 165 | logging=True) |
| 166 | elif FLAGS.model == 'gcn': |
| 167 | # Create model |
| 168 | sampler = UniformNeighborSampler(adj_info) |
| 169 | layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, 2*FLAGS.dim_1), |
| 170 | SAGEInfo("node", sampler, FLAGS.samples_2, 2*FLAGS.dim_2)] |
| 171 | |
| 172 | model = SampleAndAggregate(placeholders, |
| 173 | features, |
| 174 | adj_info, |
| 175 | minibatch.deg, |
| 176 | layer_infos=layer_infos, |
| 177 | aggregator_type="gcn", |
| 178 | model_size=FLAGS.model_size, |
| 179 | identity_dim = FLAGS.identity_dim, |
| 180 | concat=False, |
| 181 | logging=True) |
| 182 | |
| 183 | elif FLAGS.model == 'graphsage_seq': |
| 184 | sampler = UniformNeighborSampler(adj_info) |
| 185 | layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), |
| 186 | SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)] |
| 187 | |
| 188 | model = SampleAndAggregate(placeholders, |
| 189 | features, |
no test coverage detected