Layer Normalization layer, as described in the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
LayerNorm(x, epsilon=1e-5, *, center=True, scale=True, gamma_initializer=tf.ones_initializer(), data_format='channels_last')
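Written out, the computation done by the `tf.nn.moments` and `tf.nn.batch_normalization` calls in `LayerNorm` is as follows for a sample $n$. The notation is our own: $i$ ranges over the $M$ non-batch positions of the sample (every channel, plus height and width for 4D input), and $c(i)$ is the channel that position belongs to.

$$
\mu_n = \frac{1}{M}\sum_{i} x_{n,i}, \qquad
\sigma_n^2 = \frac{1}{M}\sum_{i} \left(x_{n,i} - \mu_n\right)^2,
$$

$$
y_{n,i} = \gamma_{c(i)} \, \frac{x_{n,i} - \mu_n}{\sqrt{\sigma_n^2 + \epsilon}} + \beta_{c(i)}.
$$

With `scale=False`, $\gamma$ is fixed to 1; with `center=False`, $\beta$ is fixed to 0. The variance is the biased (divide-by-$M$) estimate that `tf.nn.moments` returns.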
```python
# Assumes `tf` (TensorFlow 1.x graph API), `get_data_format` and `VariableHolder`
# are imported at module level (imports not shown in this excerpt).
# The decorator applied to LayerNorm is cut off here; its final lines are:
#         'gamma_init': 'gamma_initializer',
#     })
def LayerNorm(
        x, epsilon=1e-5, *,
        center=True, scale=True,
        gamma_initializer=tf.ones_initializer(),
        data_format='channels_last'):
    """
    Layer Normalization layer, as described in the paper:
    `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.

    Args:
        x (tf.Tensor): a 4D or 2D tensor. When 4D, the layout should match data_format.
        epsilon (float): epsilon to avoid divide-by-zero.
        center, scale (bool): whether to use the extra affine transformation or not.
    """
    data_format = get_data_format(data_format, keras_mode=False)
    shape = x.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4]

    # Per-sample statistics over every non-batch axis
    # (channels, plus H and W for 4D input).
    mean, var = tf.nn.moments(x, list(range(1, len(shape))), keep_dims=True)

    # Shape that lets per-channel beta/gamma broadcast against x.
    if data_format == 'NCHW':
        chan = shape[1]
        new_shape = [1, chan, 1, 1]
    else:
        chan = shape[-1]
        new_shape = [1, 1, 1, chan]
    if ndims == 2:
        new_shape = [1, chan]

    if center:
        beta = tf.get_variable('beta', [chan], initializer=tf.constant_initializer())
        beta = tf.reshape(beta, new_shape)
    else:
        beta = tf.zeros([1] * ndims, name='beta')
    if scale:
        gamma = tf.get_variable('gamma', [chan], initializer=gamma_initializer)
        gamma = tf.reshape(gamma, new_shape)
    else:
        gamma = tf.ones([1] * ndims, name='gamma')

    # Computes (x - mean) * gamma / sqrt(var + epsilon) + beta.
    ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')

    # Attach the (reshaped) affine parameters to the output tensor.
    vh = ret.variables = VariableHolder()
    if scale:
        vh.gamma = gamma
    if center:
        vh.beta = beta
    return ret
```
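To make the axis handling concrete, here is a minimal pure-Python reference sketch of the forward pass, assuming nested lists stand in for tensors and `data_format` has already been resolved to `"NHWC"` or `"NCHW"` (the form the TensorFlow code compares against). The names `layer_norm_reference`, `_flatten` and `_map_with_index` are hypothetical and not part of tensorpack. It takes one mean and variance per sample over all non-batch axes, mirroring the `tf.nn.moments` call, and then applies per-channel `gamma`/`beta`, reading the channel index from the position the broadcast shape `new_shape` would select.

```python
import math


def _flatten(t):
    """Yield every scalar of a nested list in row-major order."""
    if isinstance(t, list):
        for item in t:
            yield from _flatten(item)
    else:
        yield t


def _map_with_index(t, fn, path=()):
    """Rebuild nested list t, replacing each scalar v at index path p with fn(v, p)."""
    if isinstance(t, list):
        return [_map_with_index(item, fn, path + (i,)) for i, item in enumerate(t)]
    return fn(t, path)


def layer_norm_reference(x, gamma, beta, epsilon=1e-5, data_format="NHWC"):
    """Hypothetical reference for 2D [N, C] or 4D NHWC/NCHW nested-list input."""
    out = []
    for sample in x:
        # One mean/variance per sample, over all of its elements.
        values = list(_flatten(sample))
        mean = sum(values) / len(values)
        # Biased variance (divide by M), as tf.nn.moments computes it.
        var = sum((v - mean) ** 2 for v in values) / len(values)
        inv_std = 1.0 / math.sqrt(var + epsilon)

        def affine(v, path):
            # path indexes inside one sample: (c,) for 2D, (c, h, w) for NCHW,
            # (h, w, c) for NHWC. For 2D both branches pick the same index.
            c = path[0] if data_format == "NCHW" else path[-1]
            return (v - mean) * inv_std * gamma[c] + beta[c]

        # affine is used immediately, so it sees this sample's mean/inv_std.
        out.append(_map_with_index(sample, affine))
    return out


if __name__ == "__main__":
    print(layer_norm_reference([[1.0, 2.0, 3.0, 4.0]], [1.0] * 4, [0.0] * 4))
```

Because the statistics are shared across the whole sample, the NHWC and NCHW layouts of the same data produce transposed copies of the same output; only the axis used to look up `gamma`/`beta` changes.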