()
| 91 | return bn_decay |
| 92 | |
| 93 | def train(): |
| 94 | with tf.Graph().as_default(): |
| 95 | with tf.device('/gpu:'+str(GPU_INDEX)): |
| 96 | pointclouds_pl, labels_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT) |
| 97 | is_training_pl = tf.placeholder(tf.bool, shape=()) |
| 98 | print(is_training_pl) |
| 99 | |
| 100 | # Note the global_step=batch parameter to minimize. |
| 101 | # That tells the optimizer to helpfully increment the 'batch' parameter for you every time it trains. |
| 102 | batch = tf.Variable(0) |
| 103 | bn_decay = get_bn_decay(batch) |
| 104 | tf.summary.scalar('bn_decay', bn_decay) |
| 105 | |
| 106 | # Get model and loss |
| 107 | pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl, bn_decay=bn_decay) |
| 108 | loss = MODEL.get_loss(pred, labels_pl, end_points) |
| 109 | tf.summary.scalar('loss', loss) |
| 110 | |
| 111 | correct = tf.equal(tf.argmax(pred, 1), tf.to_int64(labels_pl)) |
| 112 | accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE) |
| 113 | tf.summary.scalar('accuracy', accuracy) |
| 114 | |
| 115 | # Get training operator |
| 116 | learning_rate = get_learning_rate(batch) |
| 117 | tf.summary.scalar('learning_rate', learning_rate) |
| 118 | if OPTIMIZER == 'momentum': |
| 119 | optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=MOMENTUM) |
| 120 | elif OPTIMIZER == 'adam': |
| 121 | optimizer = tf.train.AdamOptimizer(learning_rate) |
| 122 | train_op = optimizer.minimize(loss, global_step=batch) |
| 123 | |
| 124 | # Add ops to save and restore all the variables. |
| 125 | saver = tf.train.Saver() |
| 126 | |
| 127 | # Create a session |
| 128 | config = tf.ConfigProto() |
| 129 | config.gpu_options.allow_growth = True |
| 130 | config.allow_soft_placement = True |
| 131 | config.log_device_placement = False |
| 132 | sess = tf.Session(config=config) |
| 133 | |
| 134 | # Add summary writers |
| 135 | #merged = tf.merge_all_summaries() |
| 136 | merged = tf.summary.merge_all() |
| 137 | train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'), |
| 138 | sess.graph) |
| 139 | test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test')) |
| 140 | |
| 141 | # Init variables |
| 142 | init = tf.global_variables_initializer() |
| 143 | # To fix the bug introduced in TF 0.12.1 as in |
| 144 | # http://stackoverflow.com/questions/41543774/invalidargumenterror-for-tensor-bool-tensorflow-0-12-1 |
| 145 | #sess.run(init) |
| 146 | sess.run(init, {is_training_pl: True}) |
| 147 | |
| 148 | ops = {'pointclouds_pl': pointclouds_pl, |
| 149 | 'labels_pl': labels_pl, |
| 150 | 'is_training_pl': is_training_pl, |
no test coverage detected