ops: dict mapping from string to tf ops
(sess, ops, train_writer)
| 169 | |
| 170 | |
def train_one_epoch(sess, ops, train_writer):
    """Run one training epoch over every file listed in TRAIN_FILES.

    Files are visited in a freshly shuffled order. For each file:
    - the point clouds are truncated to the first NUM_POINT points and shuffled;
    - the data is split into full batches of BATCH_SIZE samples, and any
      trailing partial batch is dropped;
    - each batch is augmented by rotation and jittering and then fed through
      one `ops['train_op']` step.
    The per-file mean loss and accuracy are reported via `log_string`.

    Args:
        sess: session object whose `run(fetches, feed_dict=...)` executes the
            ops (a TensorFlow session, per the original "tf ops" docstring).
        ops: dict mapping from string to tf ops. Keys read here:
            'pointclouds_pl', 'labels_pl', 'is_training_pl' (feed targets) and
            'merged', 'step', 'train_op', 'loss', 'pred' (fetches).
        train_writer: object with `add_summary(summary, step)`; it receives
            the merged summary of every batch, tagged with the fetched step.
    """
    is_training = True

    # Visit the training files in a different random order each epoch.
    train_file_idxs = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(train_file_idxs)

    for fn, file_idx in enumerate(train_file_idxs):
        log_string('----' + str(fn) + '-----')
        current_data, current_label = provider.loadDataFile(TRAIN_FILES[file_idx])
        # Keep only the first NUM_POINT points of every cloud (axis 1 is sliced).
        current_data = current_data[:, 0:NUM_POINT, :]
        current_data, current_label, _ = provider.shuffle_data(
            current_data, np.squeeze(current_label))
        current_label = np.squeeze(current_label)

        file_size = current_data.shape[0]
        # Only full batches are used; the last file_size % BATCH_SIZE samples
        # of this (already shuffled) file are skipped for this epoch.
        num_batches = file_size // BATCH_SIZE
        if num_batches == 0:
            # Without this guard the mean loss / accuracy divisions below
            # raise ZeroDivisionError and abort the whole epoch.
            log_string('skipped: file has fewer than %d samples' % BATCH_SIZE)
            continue

        total_correct = 0
        total_seen = 0
        loss_sum = 0

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx + 1) * BATCH_SIZE

            # Augment batched point clouds by rotation and jittering
            rotated_data = provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :])
            jittered_data = provider.jitter_point_cloud(rotated_data)
            batch_label = current_label[start_idx:end_idx]
            feed_dict = {ops['pointclouds_pl']: jittered_data,
                         ops['labels_pl']: batch_label,
                         ops['is_training_pl']: is_training}
            summary, step, _, loss_val, pred_val = sess.run(
                [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
                feed_dict=feed_dict)
            train_writer.add_summary(summary, step)
            # argmax over axis 1 -- presumably pred is (batch, num_classes)
            # scores; confirm against the model definition.
            pred_val = np.argmax(pred_val, 1)
            correct = np.sum(pred_val == batch_label)
            total_correct += correct
            total_seen += BATCH_SIZE
            loss_sum += loss_val

        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))
| 214 | |
| 215 | |
| 216 | def eval_one_epoch(sess, ops, test_writer): |