```python
import tensorflow as tf  # TF 1.x only: tf.contrib was removed in TF 2.0.


def construct_scalar_host_call(metric_dict, model_dir, prefix=""):
  """Construct a host call to log scalars when training on TPU.

  Args:
    metric_dict: A dict of the scalar tensors to be logged.
    model_dir: The location to write the summary.
    prefix: The prefix (if any) to prepend to the metric names.

  Returns:
    A tuple of (function, args_to_be_passed_to_said_function).
  """
  metric_names = list(metric_dict.keys())

  def host_call_fn(global_step, *args):
    """Training host call. Creates scalar summaries for training metrics.

    This function is executed on the CPU and should not directly reference
    any Tensors in the rest of the `model_fn`. To pass Tensors from the
    model to `host_call_fn`, provide them as part of the `host_call`. See
    https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
    for more information.

    Arguments should match the list of `Tensor` objects passed as the second
    element in the tuple passed to `host_call`.

    Args:
      global_step: `Tensor` with shape `[batch]` for the global_step.
      *args: Remaining tensors to log, in the same order as `metric_names`.

    Returns:
      List of summary ops to run on the CPU host.
    """
    step = global_step[0]
    with tf.contrib.summary.create_file_writer(
        logdir=model_dir, filename_suffix=".host_call").as_default():
      with tf.contrib.summary.always_record_summaries():
        for i, name in enumerate(metric_names):
          tf.contrib.summary.scalar(prefix + name, args[i][0], step=step)

        return tf.contrib.summary.all_summary_ops()

  # To log the current learning rate and gradient norm to TensorBoard, the
  # summary ops need to run on the host CPU via host_call. host_call expects
  # Tensors with a leading (batch) dimension, so reshape each scalar to shape
  # [1]. The Tensors from all TPU cores are implicitly concatenated along that
  # dimension, which is why host_call_fn reads only element [0].
  global_step_tensor = tf.reshape(
      tf.compat.v1.train.get_or_create_global_step(), [1])
  other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names]

  return host_call_fn, [global_step_tensor] + other_tensors
```
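The shape bookkeeping in `construct_scalar_host_call` is easy to lose track of. Each core contributes `[1]`-shaped tensors, `host_call` concatenates them along the leading dimension, and `host_call_fn` reads element 0. The sketch below mimics that flow with plain Python lists so it runs without TensorFlow or a TPU. It makes several assumptions:

- `construct_scalar_host_call_sketch` and `simulate_tpu_host_call` are hypothetical names.
- The returned dict stands in for the summary file writer.
- The "concatenate each argument across cores along dimension 0" rule is modeled on the `TPUEstimatorSpec` documentation. It is not TensorFlow's actual implementation.

```python
from typing import Callable, Dict, List, Tuple

# Plain Python lists stand in for rank-1 tensors throughout this sketch.
Vector = List[float]


def construct_scalar_host_call_sketch(
    metric_dict: Dict[str, float], global_step: int, prefix: str = ""
) -> Tuple[Callable[..., Dict[str, Tuple[int, float]]], List[Vector]]:
  """Hypothetical TensorFlow-free mimic of construct_scalar_host_call."""
  metric_names = list(metric_dict.keys())

  def host_call_fn(global_step_vec: Vector, *args: Vector):
    # Like the original, read only element 0 of each concatenated vector.
    step = int(global_step_vec[0])
    # A tag -> (step, value) dict stands in for the summary file writer.
    return {prefix + name: (step, args[i][0])
            for i, name in enumerate(metric_names)}

  # Give every scalar a leading dimension of size 1, like tf.reshape(x, [1]).
  host_args = [[float(global_step)]]
  host_args += [[float(metric_dict[name])] for name in metric_names]
  return host_call_fn, host_args


def simulate_tpu_host_call(fn, per_shard_args: List[List[Vector]]):
  """Assumed behavior: concatenate each argument across shards, call fn once."""
  num_args = len(per_shard_args[0])
  # Argument j from every shard is joined along its leading dimension,
  # so each merged vector has one entry per shard.
  merged = [[value for shard in per_shard_args for value in shard[j]]
            for j in range(num_args)]
  return fn(*merged), merged


if __name__ == "__main__":
  fn, args = construct_scalar_host_call_sketch(
      {"learning_rate": 0.1, "gradient_norm": 2.5},
      global_step=100, prefix="training/")
  # Pretend 8 TPU cores each produced identical [1]-shaped arguments.
  summaries, merged = simulate_tpu_host_call(fn, [args] * 8)
  print(summaries)
  print(len(merged[0]))  # 8
```

Like the original, the sketch keeps only element 0 of each concatenated vector. That is reasonable only when every core reports the same value, as it does for the global step. The core count of 8 is purely illustrative.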