MCPcopy
hub / github.com/charlesq34/pointnet / train_one_epoch

Function train_one_epoch

train.py:171–213  ·  view source on GitHub ↗

ops: dict mapping from string to tf ops

(sess, ops, train_writer)

Source from the content-addressed store, hash-verified

169
170
def train_one_epoch(sess, ops, train_writer):
    """Run one training epoch over every file in TRAIN_FILES.

    Files are visited in a freshly shuffled order. For each file, every
    sample is truncated to its first NUM_POINT points and passed through
    ``provider.shuffle_data``. The samples are then fed through the graph in
    batches of exactly BATCH_SIZE. Each batch is augmented by rotation and
    jittering before the train op is run. Mean loss and accuracy are logged
    per file, because the running counters are reset for every file.

    Trailing samples that do not fill a complete batch are dropped.
    NOTE(review): presumably this is because the placeholders have a fixed
    batch dimension; confirm. A file with fewer than BATCH_SIZE samples
    yields no batches. It is logged and skipped, rather than raising
    ZeroDivisionError when its mean metrics are computed.

    Args:
        sess: session used to run the graph (presumably a ``tf.Session``).
        ops: dict mapping from string to tf ops. Must contain:
            - 'pointclouds_pl', 'labels_pl', 'is_training_pl': placeholders
              fed on every batch.
            - 'merged', 'step', 'train_op', 'loss', 'pred': ops fetched on
              every batch.
        train_writer: summary writer; ``add_summary(summary, step)`` is
            called once per batch.
    """
    is_training = True

    # Shuffle train files
    train_file_idxs = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(train_file_idxs)

    for fn, file_idx in enumerate(train_file_idxs):
        log_string('----' + str(fn) + '-----')
        current_data, current_label = provider.loadDataFile(TRAIN_FILES[file_idx])
        # Keep only the first NUM_POINT points (axis 1) of every sample.
        current_data = current_data[:, 0:NUM_POINT, :]
        current_data, current_label, _ = provider.shuffle_data(
            current_data, np.squeeze(current_label))
        current_label = np.squeeze(current_label)

        file_size = current_data.shape[0]
        num_batches = file_size // BATCH_SIZE
        if num_batches == 0:
            # No complete batch fits in this file: there is nothing to train
            # on, and the mean metrics below would divide by zero.
            log_string('skipping file: %d samples < batch size %d'
                       % (file_size, BATCH_SIZE))
            continue

        total_correct = 0
        total_seen = 0
        loss_sum = 0

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx + 1) * BATCH_SIZE
            batch_label = current_label[start_idx:end_idx]

            # Augment batched point clouds by rotation and jittering
            rotated_data = provider.rotate_point_cloud(
                current_data[start_idx:end_idx, :, :])
            jittered_data = provider.jitter_point_cloud(rotated_data)
            feed_dict = {ops['pointclouds_pl']: jittered_data,
                         ops['labels_pl']: batch_label,
                         ops['is_training_pl']: is_training}
            summary, step, _, loss_val, pred_val = sess.run(
                [ops['merged'], ops['step'], ops['train_op'],
                 ops['loss'], ops['pred']],
                feed_dict=feed_dict)
            train_writer.add_summary(summary, step)
            # Predicted class is the argmax over axis 1.
            # NOTE(review): presumably 'pred' yields (batch, num_classes)
            # scores; confirm.
            pred_val = np.argmax(pred_val, 1)
            total_correct += np.sum(pred_val == batch_label)
            total_seen += BATCH_SIZE
            loss_sum += loss_val

        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))
214
215
216def eval_one_epoch(sess, ops, test_writer):

Callers 1

train (function) · 0.70

Calls 1

log_string (function) · 0.70

Tested by

no test coverage detected