
Function LayerNorm

tensorpack/models/layer_norm.py, lines 23–71 (github.com/tensorpack/tensorpack)

Layer Normalization layer, as described in the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`_. It normalizes each sample of a 2D or 4D tensor over all of its non-batch axes, then optionally applies a learned per-channel scale (`gamma`) and shift (`beta`). When the input is 4D, its layout must match `data_format`.
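Reading the source below, the computation for one sample i amounts to the following. The notation is ours, not tensorpack's: j runs over the N elements of that sample (all non-batch axes), and c(j) is the channel that element j belongs to. The variance is the population variance (divided by N), and `tf.nn.batch_normalization` combines the terms as (x − mean) · rsqrt(var + ε) · gamma + beta.

```latex
\mu_i = \frac{1}{N} \sum_{j=1}^{N} x_{ij}, \qquad
\sigma_i^2 = \frac{1}{N} \sum_{j=1}^{N} \left( x_{ij} - \mu_i \right)^2, \qquad
y_{ij} = \gamma_{c(j)} \, \frac{x_{ij} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} + \beta_{c(j)}
```

With `center=False` the source uses beta = 0, and with `scale=False` it uses gamma = 1.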

LayerNorm(
        x, epsilon=1e-5, *,
        center=True, scale=True,
        gamma_initializer=tf.ones_initializer(),
        data_format='channels_last')
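The `data_format` argument only decides which axis holds the channels, and therefore how the per-channel `gamma` and `beta` are reshaped so they broadcast against `x`; the mean and variance are always taken over every non-batch axis. The branch in the source can be mirrored by a small pure-Python helper. `param_shape` is a hypothetical name used only for illustration, and it assumes `data_format` has already been converted to `'NHWC'` or `'NCHW'`, which is what the source's comparison against `'NCHW'` after `get_data_format(data_format, keras_mode=False)` implies.

```python
def param_shape(input_shape, data_format="NHWC"):
    """Return (num_channels, broadcast_shape) for gamma/beta.

    Hypothetical helper mirroring the shape logic in LayerNorm. Assumes
    data_format is already 'NHWC' or 'NCHW'.
    """
    ndims = len(input_shape)
    if ndims not in (2, 4):
        # The source uses `assert ndims in [2, 4]`; raising is the same check.
        raise ValueError("LayerNorm expects a 2D or 4D input")
    if data_format == "NCHW":
        chan = input_shape[1]           # channels right after the batch axis
        new_shape = [1, chan, 1, 1]
    else:
        chan = input_shape[-1]          # channels last
        new_shape = [1, 1, 1, chan]
    if ndims == 2:
        # For 2D inputs axis 1 is both the first and the last axis.
        new_shape = [1, chan]
    return chan, new_shape


print(param_shape([8, 32, 32, 64]))          # (64, [1, 1, 1, 64])
print(param_shape([8, 64, 32, 32], "NCHW"))  # (64, [1, 64, 1, 1])
print(param_shape([8, 128]))                 # (128, [1, 128])
```

Because `tf.get_variable` creates `gamma` and `beta` with shape `[chan]`, this reshape is what lets a single vector scale the right axis in either layout.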


def LayerNorm(
        x, epsilon=1e-5, *,
        center=True, scale=True,
        gamma_initializer=tf.ones_initializer(),
        data_format='channels_last'):
    """
    Layer Normalization layer, as described in the paper:
    `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.

    Args:
        x (tf.Tensor): a 4D or 2D tensor. When 4D, the layout should match data_format.
        epsilon (float): epsilon to avoid divide-by-zero.
        center, scale (bool): whether to use the extra affine transformation or not.
    """
    data_format = get_data_format(data_format, keras_mode=False)
    shape = x.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4]

    # Per-sample statistics over every non-batch axis (axes 1 .. ndims-1).
    mean, var = tf.nn.moments(x, list(range(1, len(shape))), keep_dims=True)

    # gamma/beta hold one value per channel; choose a shape that broadcasts against x.
    if data_format == 'NCHW':
        chan = shape[1]
        new_shape = [1, chan, 1, 1]
    else:
        chan = shape[-1]
        new_shape = [1, 1, 1, chan]
    if ndims == 2:
        new_shape = [1, chan]

    if center:
        beta = tf.get_variable('beta', [chan], initializer=tf.constant_initializer())
        beta = tf.reshape(beta, new_shape)
    else:
        beta = tf.zeros([1] * ndims, name='beta')
    if scale:
        gamma = tf.get_variable('gamma', [chan], initializer=gamma_initializer)
        gamma = tf.reshape(gamma, new_shape)
    else:
        gamma = tf.ones([1] * ndims, name='gamma')

    # Computes (x - mean) / sqrt(var + epsilon) * gamma + beta.
    ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')

    # Attach gamma/beta (the reshaped tensors) to the output as ret.variables.
    vh = ret.variables = VariableHolder()
    if scale:
        vh.gamma = gamma
    if center:
        vh.beta = beta
    return ret
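To check the numbers by hand, the 2D case of the listing above reduces to a few lines of plain Python. This is a hypothetical reference sketch (`layer_norm_2d` is not part of tensorpack) that assumes the input is a list of rows, one row per sample; passing `gamma=None` or `beta=None` stands in for `scale=False` or `center=False`.

```python
import math


def layer_norm_2d(x, gamma=None, beta=None, epsilon=1e-5):
    """Plain-Python layer normalization for a 2D input (list of rows).

    Hypothetical reference helper, not part of tensorpack. Each row is one
    sample; it is normalized with its own mean and population variance,
    then scaled by gamma and shifted by beta per channel.
    """
    out = []
    for row in x:
        n = len(row)
        mean = sum(row) / n
        var = sum((v - mean) ** 2 for v in row) / n  # divide by n, not n - 1
        inv_std = 1.0 / math.sqrt(var + epsilon)     # epsilon avoids divide-by-zero
        g = gamma if gamma is not None else [1.0] * n  # scale=False -> gamma = 1
        b = beta if beta is not None else [0.0] * n    # center=False -> beta = 0
        out.append([g[c] * (v - mean) * inv_std + b[c] for c, v in enumerate(row)])
    return out


# The second row is constant: its variance is 0, epsilon keeps the result finite.
y = layer_norm_2d([[1.0, 2.0, 3.0, 4.0], [10.0, 10.0, 10.0, 10.0]])
print(y[1])  # [0.0, 0.0, 0.0, 0.0]
```

For a 4D input the same idea applies with each sample's H×W×C values pooled into one set for the mean and variance, while `gamma` and `beta` stay per channel.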

Callers

nothing calls this directly

Calls (3)

get_data_format · Function · 0.85
VariableHolder · Class · 0.85
get_variable · Method · 0.80

Tested by

no test coverage detected