Summarize the moving average for scalar tensors. This function is a no-op if not called from the main training tower. See the tutorial at https://tensorpack.readthedocs.io/tutorial/summary.html for details. Args: args: scalar tensors to summarize; decay (float): the decay rate. Defaults to 0.95.
(*args, **kwargs)
| 196 | |
| 197 | |
| 198 | def add_moving_summary(*args, **kwargs): |
| 199 | """ |
| 200 | Summarize the moving average for scalar tensors. |
| 201 | This function is a no-op if not calling from main training tower. |
| 202 | See tutorial at https://tensorpack.readthedocs.io/tutorial/summary.html |
| 203 | |
| 204 | Args: |
| 205 | args: scalar tensors to summarize |
| 206 | decay (float): the decay rate. Defaults to 0.95. |
| 207 | collection (str or None): the name of the collection to add EMA-maintaining ops. |
| 208 | The default will work together with the default |
| 209 | :class:`MovingAverageSummary` callback. |
| 210 | summary_collections ([str]): the names of collections to add the |
| 211 | summary op. Default is TF's default (`tf.GraphKeys.SUMMARIES`). |
| 212 | |
| 213 | Returns: |
| 214 | [tf.Tensor]: |
| 215 | list of tensors returned by assign_moving_average, |
| 216 | which can be used to maintain the EMA. |
| 217 | """ |
| 218 | decay = kwargs.pop('decay', 0.95) |
| 219 | coll = kwargs.pop('collection', MOVING_SUMMARY_OPS_KEY) |
| 220 | summ_coll = kwargs.pop('summary_collections', None) |
| 221 | assert len(kwargs) == 0, "Unknown arguments: " + str(kwargs) |
| 222 | |
| 223 | ctx = get_current_tower_context() |
| 224 | # allow ctx to be none |
| 225 | if ctx is not None and not ctx.is_main_training_tower: |
| 226 | return [] |
| 227 | |
| 228 | graph = tf.get_default_graph() |
| 229 | try: |
| 230 | control_flow_ctx = graph._get_control_flow_context() |
| 231 | # XLA does not support summaries anyway |
| 232 | # However, this function will generate unnecessary dependency edges, |
| 233 | # which makes the tower function harder to compile under XLA, so we skip it |
| 234 | if control_flow_ctx is not None and control_flow_ctx.IsXLAContext(): |
| 235 | return |
| 236 | except Exception: |
| 237 | pass |
| 238 | |
| 239 | if tf.get_variable_scope().reuse is True: |
| 240 | logger.warn("add_moving_summary() called under reuse=True scope, ignored.") |
| 241 | return [] |
| 242 | |
| 243 | for x in args: |
| 244 | assert isinstance(x, (tf.Tensor, tf.Variable)), x |
| 245 | assert x.get_shape().ndims == 0, \ |
| 246 | "add_moving_summary() only accepts scalar tensor! Got one with {}".format(x.get_shape()) |
| 247 | |
| 248 | from ..graph_builder.utils import override_to_local_variable |
| 249 | ema_ops = [] |
| 250 | for c in args: |
| 251 | name = re.sub('tower[0-9]+/', '', c.op.name) |
| 252 | with tf.name_scope(None), override_to_local_variable(True): |
| 253 | if not c.dtype.is_floating: |
| 254 | c = tf.cast(c, tf.float32) |
| 255 | # assign_moving_average creates variables with op names, therefore clear ns first. |
no test coverage detected