(_)
| 93 | |
| 94 | |
def main(_):
    """Train a Grover model with a TPUEstimator, configured entirely from FLAGS.

    Steps:
      1. Load the model config from ``FLAGS.config_file``.
      2. Create ``FLAGS.output_dir``.
      3. Expand the comma-separated glob patterns in ``FLAGS.input_file``.
      4. Build a TPU ``RunConfig`` and ``TPUEstimator``.
      5. Train until ``FLAGS.num_train_steps``.

    Args:
      _: Unused positional argument (presumably the argv list passed by
        ``tf.app.run`` -- TODO confirm against the module's entry point).

    Raises:
      ValueError: If no files match any pattern in ``FLAGS.input_file``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    news_config = GroverConfig.from_json_file(FLAGS.config_file)

    tf.gfile.MakeDirs(FLAGS.output_dir)

    # Expand every comma-separated glob pattern. Surrounding whitespace is
    # stripped and empty entries (e.g. from a trailing comma) are skipped, so
    # "a,b", "a, b" and "a,b," all mean the same thing.
    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_pattern = input_pattern.strip()
        if input_pattern:
            input_files.extend(tf.gfile.Glob(input_pattern))

    # Fail fast: with no matching files, training would otherwise fail much
    # later inside the input pipeline with a far less helpful error.
    if not input_files:
        raise ValueError(
            "No input files matched --input_file=%r" % FLAGS.input_file)

    tf.logging.info("*** Input Files ***")
    for input_file in input_files:
        tf.logging.info(" %s", input_file)

    # A cluster resolver is only built when a TPU is both requested and named.
    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        # None => retain every checkpoint rather than only the most recent few.
        keep_checkpoint_max=None,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    model_fn = model_fn_builder(
        news_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=FLAGS.num_train_steps,
        num_warmup_steps=FLAGS.num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
    )

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        # Evaluation reuses the training batch size; this script has no
        # separate eval batch-size flag.
        eval_batch_size=FLAGS.train_batch_size,
        params={'model_dir': FLAGS.output_dir},
    )

    tf.logging.info("***** Running training *****")
    tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
    train_input_fn = input_fn_builder(
        input_files=input_files,
        seq_length=FLAGS.max_seq_length,
        is_training=True)

    # max_steps is an absolute global-step target, so a resumed run only
    # trains the remaining steps.
    estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
nothing calls this directly
no test coverage detected