MCP · Copy · Index your code
hub / github.com/Turing-Project/WriteGPT / main

Function main

LanguageNetwork/GPT2/train/train_wc.py:104–163  ·  view source on GitHub ↗
(_)

Source from the content-addressed store, hash-verified

102
103
104def main(_):
105 tf.logging.set_verbosity(tf.logging.INFO)
106
107 news_config = GroverConfig.from_json_file(FLAGS.config_file)
108
109 tf.gfile.MakeDirs(FLAGS.output_dir)
110
111 input_files = []
112 print(FLAGS.input_file.split(","))
113 for input_pattern in FLAGS.input_file.split(","):
114 input_files.extend(tf.gfile.Glob(input_pattern))
115
116 tf.logging.info("*** Input Files ***")
117 for input_file in input_files:
118 tf.logging.info(" %s" % input_file)
119
120 tpu_cluster_resolver = None
121 if FLAGS.use_tpu and FLAGS.tpu_name:
122 tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
123 FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
124
125 is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
126 run_config = tf.contrib.tpu.RunConfig(
127 cluster=tpu_cluster_resolver,
128 master=FLAGS.master,
129 model_dir=FLAGS.output_dir,
130 save_checkpoints_steps=FLAGS.save_checkpoints_steps,
131 keep_checkpoint_max=None,
132 tpu_config=tf.contrib.tpu.TPUConfig(
133 iterations_per_loop=FLAGS.iterations_per_loop,
134 num_shards=FLAGS.num_tpu_cores,
135 per_host_input_for_training=is_per_host))
136
137 model_fn = model_fn_builder(news_config, init_checkpoint=FLAGS.init_checkpoint,
138 learning_rate=FLAGS.learning_rate,
139 num_train_steps=FLAGS.num_train_steps,
140 num_warmup_steps=FLAGS.num_warmup_steps,
141 use_tpu=FLAGS.use_tpu,
142 )
143
144 # # If TPU is not available, this will fall back to normal Estimator on CPU
145 # # or GPU.
146 estimator = tf.contrib.tpu.TPUEstimator(
147 use_tpu=FLAGS.use_tpu,
148 model_fn=model_fn,
149 config=run_config,
150 train_batch_size=FLAGS.train_batch_size,
151 eval_batch_size=FLAGS.train_batch_size,
152 params={'model_dir': FLAGS.output_dir}
153 )
154
155 tf.logging.info("***** Running training *****")
156 tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
157 train_input_fn = input_fn_builder(
158 input_files=input_files,
159 seq_length=FLAGS.max_seq_length,
160 is_training=True)
161

Callers

nothing calls this directly

Calls 4

model_fn_builder (Function) · 0.90
input_fn_builder (Function) · 0.90
from_json_file (Method) · 0.45
train (Method) · 0.45

Tested by

no test coverage detected