Batch Renormalization layer, as described in the paper `Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models <https://arxiv.org/abs/1702.03275>`_. This implementation is a wrapper around `tf.layers.BatchNormalization`.
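The correction that Batch Renormalization applies can be illustrated in NumPy. The block below is a minimal sketch of the training-mode forward pass for a 2D (N, C) input. It follows the paper's equations, not TensorFlow's internal implementation: `r = clip(sigma_B / sigma, 1/rmax, rmax)`, `d = clip((mu_B - mu) / sigma, -dmax, dmax)`, and `x_hat = (x - mu_B) / sigma_B * r + d`. The function name `batch_renorm_train` is hypothetical. Tracking a moving standard deviation (rather than a variance) is an assumption, and so is the moving-average convention (the momentum weights the old value, as `renorm_momentum=0.99` does in the listing).

```python
import numpy as np


def batch_renorm_train(x, gamma, beta, moving_mean, moving_std,
                       rmax, dmax, epsilon=1e-5, renorm_momentum=0.99):
    """Hypothetical sketch: training-mode Batch Renormalization on an (N, C) batch.

    Returns the output and the updated moving mean / moving std.
    In a framework with autodiff, r and d are treated as constants
    (stop_gradient); NumPy has no gradients, so that step is implicit here.
    """
    batch_mean = x.mean(axis=0)
    batch_std = np.sqrt(x.var(axis=0) + epsilon)

    # Correction factors, clipped to [1/rmax, rmax] and [-dmax, dmax].
    r = np.clip(batch_std / moving_std, 1.0 / rmax, rmax)
    d = np.clip((batch_mean - moving_mean) / moving_std, -dmax, dmax)

    # Normalize with batch statistics, then correct toward the moving statistics.
    x_hat = (x - batch_mean) / batch_std * r + d
    y = gamma * x_hat + beta

    # Exponential moving averages (assumed convention: momentum weights the old value).
    new_mean = renorm_momentum * moving_mean + (1 - renorm_momentum) * batch_mean
    new_std = renorm_momentum * moving_std + (1 - renorm_momentum) * batch_std
    return y, new_mean, new_std
```

With `rmax=1` and `dmax=0` the corrections vanish and the output is ordinary batch normalization. When the limits are large enough that no clipping occurs, the output equals `gamma * (x - moving_mean) / moving_std + beta`, which is the same normalization used at inference time.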
BatchRenorm(x, rmax, dmax, *, momentum=0.9, epsilon=1e-5,
            center=True, scale=True, gamma_initializer=None,
            data_format='channels_last')
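Because `rmax` and `dmax` are scalar tensors, they can change during training. One option is to start at `rmax=1, dmax=0`, where the layer behaves exactly like batch normalization, and then relax both limits. The block below is a minimal sketch of a linear ramp. The function name `renorm_clip_schedule`, the warm-up length, the end steps, and the final values are all assumptions for illustration, not defaults of this layer. How the values are fed into the graph (for example, computed from a global step) is not shown.

```python
def renorm_clip_schedule(step, warmup_steps=5000,
                         rmax_end_step=40000, dmax_end_step=25000,
                         rmax_final=3.0, dmax_final=5.0):
    """Hypothetical linear ramp returning (rmax, dmax) for a training step.

    During warm-up, rmax == 1 and dmax == 0 (plain batch normalization).
    Afterwards each limit grows linearly until its end step, then stays fixed.
    Assumes each end step is larger than warmup_steps.
    """
    def ramp(start, end, end_step):
        if step <= warmup_steps:
            return start
        if step >= end_step:
            return end
        frac = (step - warmup_steps) / (end_step - warmup_steps)
        return start + frac * (end - start)

    return ramp(1.0, rmax_final, rmax_end_step), ramp(0.0, dmax_final, dmax_end_step)
```

For example, `renorm_clip_schedule(22500)` is halfway through the assumed `rmax` ramp and seven eighths of the way through the `dmax` ramp.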
# ...  (earlier decorator lines not shown)
#         'decay': 'momentum'
#     })
def BatchRenorm(x, rmax, dmax, *, momentum=0.9, epsilon=1e-5,
                center=True, scale=True, gamma_initializer=None,
                data_format='channels_last'):
    """
    Batch Renormalization layer, as described in the paper:
    `Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
    <https://arxiv.org/abs/1702.03275>`_.
    This implementation is a wrapper around `tf.layers.BatchNormalization` with ``renorm=True``.

    Args:
        x (tf.Tensor): a 4D tensor (NHWC, or NCHW if ``data_format='channels_first'``)
            or a 2D NC tensor.
        rmax, dmax (tf.Tensor): scalar tensors giving the maximum allowed corrections:
            ``r`` is clipped to ``[1/rmax, rmax]`` and ``d`` to ``[-dmax, dmax]``.
        momentum (float): momentum (decay rate) of the moving averages.
            The legacy argument name ``decay`` maps to this argument.
        epsilon (float): small constant added to the variance to avoid division by zero.
        center, scale (bool): whether to apply the learned bias ``beta`` and scale ``gamma``.
        gamma_initializer: initializer for ``gamma``.
        data_format (str): ``'channels_last'`` or ``'channels_first'``. For 2D inputs it is
            ignored and treated as ``'channels_first'``.

    Returns:
        tf.Tensor: a tensor named ``output`` with the same shape as x.

    Variable Names:

    * ``beta``: the bias term.
    * ``gamma``: the scale term. The normalized input ``x_hat`` is transformed by
      ``x_hat * gamma + beta``.
    * ``moving_mean, renorm_mean, renorm_mean_weight``: See TF documentation.
    * ``moving_variance, renorm_stddev, renorm_stddev_weight``: See TF documentation.
    """

    shape = x.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4]
    if ndims == 2:
        # A 2D input is (N, C), so the channel axis is 1.
        data_format = 'channels_first'

    ctx = get_current_tower_context()
    # Snapshot UPDATE_OPS so that update ops added by non-main towers can be discarded below.
    coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
    layer = tf.layers.BatchNormalization(
        axis=1 if data_format == 'channels_first' else 3,
        momentum=momentum, epsilon=epsilon,
        center=center, scale=scale,
        renorm=True,
        renorm_clipping={
            'rmin': 1.0 / rmax,
            'rmax': rmax,
            'dmax': dmax},
        renorm_momentum=0.99,
        gamma_initializer=gamma_initializer,
        fused=False,
        _reuse=tf.get_variable_scope().reuse)
    xn = layer.apply(x, training=ctx.is_training, scope=tf.get_variable_scope())

    if ctx.is_main_training_tower:
        # Expose the non-trainable statistics (moving and renorm averages) as model variables.
        for v in layer.non_trainable_variables:
            if isinstance(v, tf.Variable):
                tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
    else:
        # Drop the update ops this tower just added, so UPDATE_OPS only runs in the first tower.
        restore_collection(coll_bk)
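The argument handling at the top of the function can be isolated and checked without TensorFlow. The block below is a minimal sketch: the helper name `renorm_layer_kwargs` is hypothetical. It mirrors how the listing chooses the normalization axis and builds the `renorm_clipping` dict, and it returns a subset of the keyword arguments passed to `tf.layers.BatchNormalization` (omitting `center`, `scale`, `gamma_initializer`, and `_reuse`). It raises `ValueError` where the original uses `assert`.

```python
def renorm_layer_kwargs(shape, rmax, dmax, data_format='channels_last',
                        momentum=0.9, epsilon=1e-5):
    """Hypothetical helper mirroring BatchRenorm's axis selection and
    renorm_clipping construction (pure Python, no TensorFlow)."""
    ndims = len(shape)
    if ndims not in (2, 4):
        raise ValueError("expected a 2D (NC) or 4D input, got %d dims" % ndims)
    if ndims == 2:
        # A 2D input is (N, C): always use the channels_first axis.
        data_format = 'channels_first'
    return {
        'axis': 1 if data_format == 'channels_first' else 3,
        'momentum': momentum,
        'epsilon': epsilon,
        'renorm': True,
        # rmin is derived from rmax, so r is clipped to [1/rmax, rmax].
        'renorm_clipping': {'rmin': 1.0 / rmax, 'rmax': rmax, 'dmax': dmax},
        'renorm_momentum': 0.99,
        'fused': False,
    }
```

Deriving `rmin` as `1/rmax` means a single argument controls how far the batch standard deviation may deviate from the moving one, in either direction, by the same factor.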