hub / github.com/Turing-Project/WriteGPT / main

Function main

LanguageNetwork/GPT2/train/train_tpu.py:95–152
(_)

def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    news_config = GroverConfig.from_json_file(FLAGS.config_file)

    tf.gfile.MakeDirs(FLAGS.output_dir)

    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    tf.logging.info("*** Input Files ***")
    for input_file in input_files:
        tf.logging.info("  %s" % input_file)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=None,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    model_fn = model_fn_builder(news_config, init_checkpoint=FLAGS.init_checkpoint,
                                learning_rate=FLAGS.learning_rate,
                                num_train_steps=FLAGS.num_train_steps,
                                num_warmup_steps=FLAGS.num_warmup_steps,
                                use_tpu=FLAGS.use_tpu,
                                )

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.train_batch_size,
        params={'model_dir': FLAGS.output_dir}
    )

    tf.logging.info("***** Running training *****")
    tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
    train_input_fn = input_fn_builder(
        input_files=input_files,
        seq_length=FLAGS.max_seq_length,
        is_training=True)

    estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
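
The input-file loop in main() splits FLAGS.input_file on commas and expands each piece as a glob pattern. The sketch below reproduces that step without TensorFlow so it can be run on its own. It makes several assumptions: the stdlib glob.glob stands in for tf.gfile.Glob, the helper name expand_input_patterns and the shard file names are hypothetical, and sorted() is added only to make the result deterministic.

```python
import glob
import os
import tempfile


def expand_input_patterns(input_file_flag):
    """Expand a comma-separated string of glob patterns into a flat file list.

    Mirrors the loop in main(); glob.glob is a stand-in for tf.gfile.Glob
    (an assumption, so that the sketch runs without TensorFlow).
    """
    input_files = []
    for input_pattern in input_file_flag.split(","):
        # sorted() is an addition for determinism; it is not in the original loop.
        input_files.extend(sorted(glob.glob(input_pattern)))
    return input_files


# Usage: create a throwaway directory with hypothetical shard names.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("train_0.tfrecord", "train_1.tfrecord", "val_0.tfrecord"):
        open(os.path.join(tmp, name), "w").close()
    patterns = ",".join([
        os.path.join(tmp, "train_*.tfrecord"),
        os.path.join(tmp, "val_*.tfrecord"),
    ])
    matched = [os.path.basename(p) for p in expand_input_patterns(patterns)]
    print(matched)
```

Two properties of the loop are worth noting. The patterns are not stripped, so a value such as "a*, b*" leaves a leading space on the second pattern. A pattern that matches nothing contributes no files, and main() does not check for an empty list before building the input function.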

Callers

nothing calls this directly

Calls 4

model_fn_builder · Function · 0.90
input_fn_builder · Function · 0.90
from_json_file · Method · 0.45
train · Method · 0.45

Tested by

no test coverage detected