Repository: github.com/astorfi/lip-reading-deeplearning (branch: main)

Function main(_)

Defined in code/training_evaluation/train.py, lines 311–661.


def main(_):

    tf.logging.set_verbosity(tf.logging.INFO)

    graph = tf.Graph()
    with graph.as_default(), tf.device('/cpu:0'):
        ######################
        # Config model_deploy#
        ######################

        # Required from the data. train_data, test_data and FLAGS are
        # defined outside main() and are not shown in this excerpt.
        num_samples_per_epoch = train_data['mouth'].shape[0]
        num_batches_per_epoch = int(num_samples_per_epoch / FLAGS.batch_size)

        num_samples_per_epoch_test = test_data['mouth'].shape[0]
        num_batches_per_epoch_test = int(num_samples_per_epoch_test / FLAGS.batch_size)

        # Create global_step
        global_step = tf.Variable(0, name='global_step', trainable=False)

        #########################################
        # Configure the learning rate.          #
        #########################################
        learning_rate = _configure_learning_rate(num_samples_per_epoch, global_step)
        opt = _configure_optimizer(learning_rate)

        ######################
        # Select the network #
        ######################
        # Boolean placeholder fed at run time to switch the networks
        # between training and evaluation behavior.
        is_training = tf.placeholder(tf.bool)

        network_speech_fn = nets_factory.get_network_fn(
            FLAGS.model_speech_name,
            num_classes=2,
            weight_decay=FLAGS.weight_decay,
            is_training=is_training)

        network_mouth_fn = nets_factory.get_network_fn(
            FLAGS.model_mouth_name,
            num_classes=2,
            weight_decay=FLAGS.weight_decay,
            is_training=is_training)

        #####################################
        # Select the preprocessing function #
        #####################################

        # TODO: Do some preprocessing if necessary.

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        # with tf.device(deploy_config.inputs_device()):
        """
        Define the placeholders and create the batch tensor.
        """

        # ... remainder of main() (train.py lines 369–661) is not shown ...
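The excerpt ends before `_configure_learning_rate` is defined. Its section banners match TF-slim's `train_image_classifier.py`, so one plausible reading, offered here as an assumption rather than as the repository's code, is the TF-slim default: a staircase exponential decay whose step length is `num_epochs_per_decay` epochs' worth of batches. The sketch below reproduces that arithmetic in plain Python; `configure_learning_rate`, `num_batches`, `staircase_exponential_decay` and every default value are hypothetical stand-ins for the real helper and its `FLAGS`.

```python
def num_batches(num_samples: int, batch_size: int) -> int:
    """Full batches per epoch, computed as in main(); leftover samples are dropped."""
    return int(num_samples / batch_size)


def staircase_exponential_decay(initial_lr: float, global_step: int,
                                decay_steps: int, decay_rate: float) -> float:
    """initial_lr * decay_rate ** floor(global_step / decay_steps)."""
    return initial_lr * decay_rate ** (global_step // decay_steps)


def configure_learning_rate(num_samples_per_epoch: int, global_step: int,
                            batch_size: int = 32,
                            initial_lr: float = 0.01,
                            num_epochs_per_decay: float = 2.0,
                            decay_factor: float = 0.94) -> float:
    # HYPOTHETICAL: a TF-slim-style schedule; the real _configure_learning_rate
    # and the FLAGS it reads are not shown in this excerpt.
    decay_steps = int(num_samples_per_epoch * num_epochs_per_decay / batch_size)
    return staircase_exponential_decay(initial_lr, global_step, decay_steps, decay_factor)


if __name__ == "__main__":
    print(num_batches(1000, 32))  # → 31
    for step in (0, 61, 62, 124):
        print(step, configure_learning_rate(1000, step))
```

With 1,000 training samples and a batch size of 32, `main()` would see 31 full batches per epoch (the remaining 8 samples do not fill a batch), and under these assumed defaults the rate would drop by a factor of 0.94 every 62 steps.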

Callers

nothing calls this directly

Calls (3)

_configure_learning_rate · function · 0.70
_configure_optimizer · function · 0.70
average_gradients · function · 0.70
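`average_gradients` is called in the part of `main()` not shown here. Its name matches the helper from TensorFlow's multi-GPU CIFAR-10 example, which takes one list of `(gradient, variable)` pairs per tower (model replica) and returns a single list in which each variable's gradient is the mean over towers. Assuming this file's version follows that pattern, the hypothetical sketch below reproduces the logic in plain Python, with lists of floats standing in for gradient tensors and strings standing in for variables.

```python
from typing import List, Sequence, Tuple

# Stand-ins for TensorFlow objects: a gradient is a list of floats, a variable is a name.
GradVar = Tuple[List[float], str]


def average_gradients(tower_grads: Sequence[Sequence[GradVar]]) -> List[GradVar]:
    """Average each variable's gradient across towers (hypothetical sketch).

    tower_grads[i][j] is the (gradient, variable) pair for variable j on tower i;
    every tower must list the same variables in the same order.
    """
    averaged: List[GradVar] = []
    # zip(*tower_grads) groups the j-th (gradient, variable) pair from every tower.
    for grad_and_vars in zip(*tower_grads):
        var = grad_and_vars[0][1]
        if any(v != var for _, v in grad_and_vars):
            raise ValueError("towers list variables in different orders")
        num_towers = len(grad_and_vars)
        # Element-wise mean over towers for this variable's gradient.
        mean = [sum(vals) / num_towers for vals in zip(*(g for g, _ in grad_and_vars))]
        # Variables are shared across towers, so the first tower's handle is returned.
        averaged.append((mean, var))
    return averaged


if __name__ == "__main__":
    tower_0 = [([1.0, 2.0], "w"), ([0.5], "b")]
    tower_1 = [([3.0, 4.0], "w"), ([1.5], "b")]
    print(average_gradients([tower_0, tower_1]))  # → [([2.0, 3.0], 'w'), ([1.0], 'b')]
```

In the multi-tower pattern, the averaged list is what gets passed to the optimizer's `apply_gradients`, so each tower's batch contributes equally to a single update.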

Tested by

no test coverage detected