Repository: github.com/tensorpack/tensorpack

Function BatchRenorm

tensorpack/models/batch_norm.py:405–473

Batch Renormalization layer, as described in the paper Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models (https://arxiv.org/abs/1702.03275). This implementation is a wrapper around `tf.layers.BatchNormalization`, the layer class behind `tf.layers.batch_normalization`.
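For reference, the training-time transform from the paper can be sketched in plain Python for a 2D (N x C) batch. This is a minimal sketch of the paper's algorithm, not the TensorFlow code path this layer calls: `batch_renorm_train` is a hypothetical name, the moving statistics are passed in rather than updated, and gradients are not modeled (the paper treats the corrections r and d as constants during backpropagation).

```python
import math


def batch_renorm_train(batch, moving_mean, moving_std, gamma, beta,
                       rmax, dmax, eps=1e-5):
    """Hypothetical helper: Batch Renorm forward pass (training mode) for an N x C batch.

    batch is a list of N rows of C floats; moving_mean, moving_std, gamma and
    beta are per-channel lists of C floats. Returns (outputs, r, d), where r
    and d are the per-channel corrections.
    """
    n = len(batch)
    num_channels = len(batch[0])
    out = [[0.0] * num_channels for _ in range(n)]
    rs, ds = [], []
    for c in range(num_channels):
        col = [row[c] for row in batch]
        # Minibatch statistics, with epsilon folded into the std as in the paper.
        mu_b = sum(col) / n
        sigma_b = math.sqrt(sum((v - mu_b) ** 2 for v in col) / n + eps)
        # Corrections toward the moving statistics, clipped to the allowed range.
        # The paper treats r and d as constants (stop-gradient) during backprop.
        r = min(max(sigma_b / moving_std[c], 1.0 / rmax), rmax)
        d = min(max((mu_b - moving_mean[c]) / moving_std[c], -dmax), dmax)
        rs.append(r)
        ds.append(d)
        for i, v in enumerate(col):
            x_hat = (v - mu_b) / sigma_b * r + d
            out[i][c] = gamma[c] * x_hat + beta[c]
    return out, rs, ds


# Moving std far below the batch std: both corrections hit their clip limits.
out, r, d = batch_renorm_train([[1.0], [3.0]], moving_mean=[0.0], moving_std=[0.1],
                               gamma=[1.0], beta=[0.0], rmax=3.0, dmax=5.0)
print(r, d)  # [3.0] [5.0]
```

With rmax = 1 and dmax = 0 the corrections collapse to r = 1 and d = 0, which is ordinary batch normalization using batch statistics.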

(x, rmax, dmax, *, momentum=0.9, epsilon=1e-5,
                center=True, scale=True, gamma_initializer=None,
                data_format='channels_last')
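In the source below, `rmax` and `dmax` become the `renorm_clipping` dictionary passed to `tf.layers.BatchNormalization`, with `rmin` fixed at `1.0 / rmax`. Because they are tensors, they can also change over training. The sketch shows the dictionary arithmetic in plain Python together with a hypothetical linear warm-up schedule; `renorm_clipping` and `linear_ramp` are illustrative names, and the step counts and values are placeholders, not settings taken from the paper or from tensorpack.

```python
def renorm_clipping(rmax, dmax):
    """Build the same clipping dict that BatchRenorm passes to
    tf.layers.BatchNormalization: r is clipped to [1/rmax, rmax], d to [-dmax, dmax]."""
    return {'rmin': 1.0 / rmax, 'rmax': rmax, 'dmax': dmax}


def linear_ramp(step, start_step, end_step, start_value, end_value):
    """Hypothetical schedule: hold start_value up to start_step, move linearly
    to end_value at end_step, then hold end_value."""
    if step <= start_step:
        return start_value
    if step >= end_step:
        return end_value
    frac = (step - start_step) / (end_step - start_step)
    return start_value + frac * (end_value - start_value)


# Placeholder warm-up: rmax=1, dmax=0 (plain batch norm) until step 1000,
# then relax to rmax=3, dmax=5 by step 2000.
for step in (0, 1500, 3000):
    rmax = linear_ramp(step, 1000, 2000, 1.0, 3.0)
    dmax = linear_ramp(step, 1000, 2000, 0.0, 5.0)
    print(step, renorm_clipping(rmax, dmax))
```

Inside a TensorFlow graph the same ramp would be built from tensor ops on the training step so that the clipping limits update without rebuilding the layer.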


        'decay': 'momentum'
    })
def BatchRenorm(x, rmax, dmax, *, momentum=0.9, epsilon=1e-5,
                center=True, scale=True, gamma_initializer=None,
                data_format='channels_last'):
    """
    Batch Renormalization layer, as described in the paper:
    `Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
    <https://arxiv.org/abs/1702.03275>`_.
    This implementation is a wrapper around `tf.layers.BatchNormalization`,
    the layer class behind `tf.layers.batch_normalization`.

    Args:
        x (tf.Tensor): a 2D tensor in NC layout, or a 4D tensor in NHWC layout
            (NCHW when ``data_format='channels_first'``).
        rmax, dmax (tf.Tensor): scalar tensors, the maximum allowed corrections.
        momentum (float): decay rate of the moving averages (``decay`` is accepted as an alias).
        epsilon (float): epsilon to avoid divide-by-zero.
        center, scale (bool): whether to add the bias ``beta`` and the scale ``gamma``.
        gamma_initializer: initializer passed to the layer for ``gamma``.
        data_format (str): 'channels_last' or 'channels_first'; 2D inputs are always
            treated as 'channels_first'.

    Returns:
        tf.Tensor: a tensor named ``output`` with the same shape as x.

    Variable Names:

    * ``beta``: the bias term.
    * ``gamma``: the scale term. Input will be transformed by ``x * gamma + beta``.
    * ``moving_mean, renorm_mean, renorm_mean_weight``: See TF documentation.
    * ``moving_variance, renorm_stddev, renorm_stddev_weight``: See TF documentation.
    """

    shape = x.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4]
    if ndims == 2:
        data_format = 'channels_first'

    ctx = get_current_tower_context()
    coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
    layer = tf.layers.BatchNormalization(
        axis=1 if data_format == 'channels_first' else 3,
        momentum=momentum, epsilon=epsilon,
        center=center, scale=scale,
        renorm=True,
        renorm_clipping={
            'rmin': 1.0 / rmax,
            'rmax': rmax,
            'dmax': dmax},
        renorm_momentum=0.99,
        gamma_initializer=gamma_initializer,
        fused=False,
        _reuse=tf.get_variable_scope().reuse)
    xn = layer.apply(x, training=ctx.is_training, scope=tf.get_variable_scope())

    if ctx.is_main_training_tower:
        for v in layer.non_trainable_variables:
            if isinstance(v, tf.Variable):
                tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
    else:
        # only run UPDATE_OPS in the first tower
        restore_collection(coll_bk)
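The shape handling at the top of the body decides which axis holds the channels. As a plain-Python sketch (`bn_axis` is a hypothetical helper, not part of tensorpack), the rule is:

```python
def bn_axis(shape, data_format='channels_last'):
    """Hypothetical helper: the channel axis BatchRenorm normalizes over.

    Mirrors the source: only 2D (NC) and 4D inputs are accepted, and 2D
    inputs are always treated as channels_first, so their channel axis is 1.
    """
    ndims = len(shape)
    assert ndims in [2, 4], "BatchRenorm expects a 2D or 4D tensor"
    if ndims == 2:
        data_format = 'channels_first'
    return 1 if data_format == 'channels_first' else 3


print(bn_axis([32, 64]))                          # 1
print(bn_axis([8, 32, 32, 3]))                    # 3
print(bn_axis([8, 3, 32, 32], 'channels_first'))  # 1
```

So a 2D input always normalizes over axis 1 regardless of the `data_format` argument, and a 4D input uses axis 3 unless `data_format='channels_first'` is passed.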

Callers

nothing calls this directly

Calls 5

backup_collection · Function · 0.85
restore_collection · Function · 0.85
VariableHolder · Class · 0.85
apply · Method · 0.45
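The `backup_collection` / `restore_collection` pair listed above is what keeps update ops from every tower except the first. The sketch below imitates that control flow with a plain dict standing in for TensorFlow graph collections; both helpers here are hypothetical stand-ins with their own signatures (the real tensorpack functions operate on the default graph), and the `MODEL_VARIABLES` bookkeeping done by the main tower is not modeled.

```python
def backup_collection(collections, keys):
    """Hypothetical stand-in: snapshot the named collections."""
    return {k: list(collections.get(k, [])) for k in keys}


def restore_collection(collections, backup):
    """Hypothetical stand-in: reset each backed-up collection to its snapshot,
    dropping anything that was added after the backup was taken."""
    for k, items in backup.items():
        collections[k] = list(items)


def build_layer_in_tower(collections, tower_name, is_main_training_tower):
    """Mimic the control flow of BatchRenorm's body for one tower."""
    bk = backup_collection(collections, ['UPDATE_OPS'])
    # Building the batch-norm layer registers its moving-average update op.
    collections.setdefault('UPDATE_OPS', []).append(tower_name + '/update_moving_stats')
    if not is_main_training_tower:
        # Non-main towers roll UPDATE_OPS back, so only the first tower's
        # update ops remain to be run.
        restore_collection(collections, bk)


graph_collections = {}
build_layer_in_tower(graph_collections, 'tower0', is_main_training_tower=True)
build_layer_in_tower(graph_collections, 'tower1', is_main_training_tower=False)
print(graph_collections['UPDATE_OPS'])  # ['tower0/update_moving_stats']
```

Rolling the collection back, rather than never creating the ops, lets every tower build the same layer code while only one set of moving-average updates is executed.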

Tested by

no test coverage detected