MCPcopy
hub / github.com/charlesq34/pointnet / train

Function train

train.py:93–167  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

91 return bn_decay
92
93def train():
94 with tf.Graph().as_default():
95 with tf.device('/gpu:'+str(GPU_INDEX)):
96 pointclouds_pl, labels_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT)
97 is_training_pl = tf.placeholder(tf.bool, shape=())
98 print(is_training_pl)
99
100 # Note the global_step=batch parameter to minimize.
101 # That tells the optimizer to helpfully increment the 'batch' parameter for you every time it trains.
102 batch = tf.Variable(0)
103 bn_decay = get_bn_decay(batch)
104 tf.summary.scalar('bn_decay', bn_decay)
105
106 # Get model and loss
107 pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl, bn_decay=bn_decay)
108 loss = MODEL.get_loss(pred, labels_pl, end_points)
109 tf.summary.scalar('loss', loss)
110
111 correct = tf.equal(tf.argmax(pred, 1), tf.to_int64(labels_pl))
112 accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE)
113 tf.summary.scalar('accuracy', accuracy)
114
115 # Get training operator
116 learning_rate = get_learning_rate(batch)
117 tf.summary.scalar('learning_rate', learning_rate)
118 if OPTIMIZER == 'momentum':
119 optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=MOMENTUM)
120 elif OPTIMIZER == 'adam':
121 optimizer = tf.train.AdamOptimizer(learning_rate)
122 train_op = optimizer.minimize(loss, global_step=batch)
123
124 # Add ops to save and restore all the variables.
125 saver = tf.train.Saver()
126
127 # Create a session
128 config = tf.ConfigProto()
129 config.gpu_options.allow_growth = True
130 config.allow_soft_placement = True
131 config.log_device_placement = False
132 sess = tf.Session(config=config)
133
134 # Add summary writers
135 #merged = tf.merge_all_summaries()
136 merged = tf.summary.merge_all()
137 train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'),
138 sess.graph)
139 test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'))
140
141 # Init variables
142 init = tf.global_variables_initializer()
143 # To fix the bug introduced in TF 0.12.1 as in
144 # http://stackoverflow.com/questions/41543774/invalidargumenterror-for-tensor-bool-tensorflow-0-12-1
145 #sess.run(init)
146 sess.run(init, {is_training_pl: True})
147
148 ops = {'pointclouds_pl': pointclouds_pl,
149 'labels_pl': labels_pl,
150 'is_training_pl': is_training_pl,

Callers 1

train.py · File · 0.70

Calls 5

get_bn_decay · Function · 0.70
get_learning_rate · Function · 0.70
log_string · Function · 0.70
train_one_epoch · Function · 0.70
eval_one_epoch · Function · 0.70

Tested by

no test coverage detected