hub / github.com/Turing-Project/WriteGPT / construct_scalar_host_call

Function construct_scalar_host_call

LanguageNetwork/GPT2/scripts/utils.py:182–232  ·  view source on GitHub ↗

Construct a host call to log scalars when training on TPU. Args: metric_dict: A dict of the tensors to be logged. model_dir: The location to write the summary. prefix: The prefix (if any) to prepend to the metric names. Returns: A tuple of (function, args_to_be_pass

(metric_dict, model_dir, prefix="")


def construct_scalar_host_call(metric_dict, model_dir, prefix=""):
  """Construct a host call to log scalars when training on TPU.

  Args:
    metric_dict: A dict of the tensors to be logged.
    model_dir: The location to write the summary.
    prefix: The prefix (if any) to prepend to the metric names.

  Returns:
    A tuple of (function, args_to_be_passed_to_said_function)
  """
  metric_names = list(metric_dict.keys())

  def host_call_fn(global_step, *args):
    """Training host call. Creates scalar summaries for training metrics.

    This function is executed on the CPU and should not directly reference
    any Tensors in the rest of the `model_fn`. To pass Tensors from the
    model to the `host_call_fn`, provide them as part of the `host_call`. See
    https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
    for more information.

    Arguments should match the list of `Tensor` objects passed as the second
    element in the tuple passed to `host_call`.

    Args:
      global_step: `Tensor` with shape `[batch]` for the global_step
      *args: Remaining tensors to log.

    Returns:
      List of summary ops to run on the CPU host.
    """
    # Every input carries a leading batch dimension; element 0 is the scalar.
    step = global_step[0]
    with tf.contrib.summary.create_file_writer(
        logdir=model_dir, filename_suffix=".host_call").as_default():
      with tf.contrib.summary.always_record_summaries():
        for i, name in enumerate(metric_names):
          tf.contrib.summary.scalar(prefix + name, args[i][0], step=step)

        return tf.contrib.summary.all_summary_ops()

  # To log scalars such as the current learning rate and gradient norm for
  # TensorBoard, the summary op needs to be run on the host CPU via host_call.
  # host_call expects [batch_size, ...] Tensors, thus reshape to introduce a
  # batch dimension. These Tensors are implicitly concatenated to
  # [params['batch_size']].
  global_step_tensor = tf.reshape(
      tf.compat.v1.train.get_or_create_global_step(), [1])
  other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names]

  return host_call_fn, [global_step_tensor] + other_tensors
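
The contract described in the comments above (reshape each scalar to `[1]`, let the framework concatenate along the batch dimension, then read element 0 on the host) can be illustrated without TensorFlow. The following is only a rough sketch under stated assumptions: plain Python lists stand in for tensors, the names `construct_scalar_host_call_sketch` and `run_on_host` are hypothetical and not part of the repository, and the real function writes summaries to `model_dir` rather than returning a dict.

```python
# Framework-free sketch of the host_call contract.
# Assumption: plain lists stand in for tensors; nothing here touches TensorFlow.


def construct_scalar_host_call_sketch(metric_dict, global_step, prefix=""):
    """Return (fn, args), mirroring the shape of the real return value."""
    metric_names = list(metric_dict.keys())

    def host_call_fn(step_batch, *args):
        # Each argument arrives with a leading batch dimension; take element 0.
        step = step_batch[0]
        return {
            prefix + name: (step, args[i][0])
            for i, name in enumerate(metric_names)
        }

    # Wrap every scalar in a length-1 list, analogous to tf.reshape(x, [1]).
    step_arg = [global_step]
    metric_args = [[metric_dict[name]] for name in metric_names]
    return host_call_fn, [step_arg] + metric_args


def run_on_host(host_call, num_shards):
    """Hypothetical stand-in for the framework side of host_call."""
    fn, args = host_call
    # Simulate concatenation across shards: each [1] input becomes [num_shards].
    concatenated = [arg * num_shards for arg in args]
    return fn(*concatenated)


host_call = construct_scalar_host_call_sketch(
    {"learning_rate": 0.001, "loss": 2.5}, global_step=100, prefix="training/")
logged = run_on_host(host_call, num_shards=8)
print(logged)
```

The order of the argument list matters: `host_call_fn` matches positional arguments to `metric_names` by index, which is why the real function builds `other_tensors` by iterating over `metric_names` rather than over the dict a second time.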

Callers 1

model_fn · Function · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected