(_)
| 309 | |
| 310 | |
| 311 | def main(_): |
| 312 | |
| 313 | |
| 314 | tf.logging.set_verbosity(tf.logging.INFO) |
| 315 | |
| 316 | graph = tf.Graph() |
| 317 | with graph.as_default(), tf.device('/cpu:0'): |
| 318 | ###################### |
| 319 | # Config model_deploy# |
| 320 | ###################### |
| 321 | |
| 322 | # required from data |
| 323 | num_samples_per_epoch = train_data['mouth'].shape[0] |
| 324 | num_batches_per_epoch = int(num_samples_per_epoch / FLAGS.batch_size) |
| 325 | |
| 326 | num_samples_per_epoch_test = test_data['mouth'].shape[0] |
| 327 | num_batches_per_epoch_test = int(num_samples_per_epoch_test / FLAGS.batch_size) |
| 328 | |
| 329 | # Create global_step |
| 330 | global_step = tf.Variable(0, name='global_step', trainable=False) |
| 331 | |
| 332 | ######################################### |
| 333 | # Configure the learning rate. # |
| 334 | ######################################### |
| 335 | learning_rate = _configure_learning_rate(num_samples_per_epoch, global_step) |
| 336 | opt = _configure_optimizer(learning_rate) |
| 337 | |
| 338 | ###################### |
| 339 | # Select the network # |
| 340 | ###################### |
| 341 | is_training = tf.placeholder(tf.bool) |
| 342 | |
| 343 | network_speech_fn = nets_factory.get_network_fn( |
| 344 | FLAGS.model_speech_name, |
| 345 | num_classes=2, |
| 346 | weight_decay=FLAGS.weight_decay, |
| 347 | is_training=is_training) |
| 348 | |
| 349 | network_mouth_fn = nets_factory.get_network_fn( |
| 350 | FLAGS.model_mouth_name, |
| 351 | num_classes=2, |
| 352 | weight_decay=FLAGS.weight_decay, |
| 353 | is_training=is_training) |
| 354 | |
| 355 | ##################################### |
| 356 | # Select the preprocessing function # |
| 357 | ##################################### |
| 358 | |
| 359 | # TODO: Do some preprocessing if necessary. |
| 360 | |
| 361 | ############################################################## |
| 362 | # Create a dataset provider that loads data from the dataset # |
| 363 | ############################################################## |
| 364 | # with tf.device(deploy_config.inputs_device()): |
| 365 | """ |
| 366 | Define the place holders and creating the batch tensor. |
| 367 | """ |
| 368 |
nothing calls this directly
no test coverage detected