
Function InstanceNorm

Defined in tensorpack/models/layer_norm.py, lines 80–131 (github.com/tensorpack/tensorpack).

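Instance Normalization, as in the paper "Instance Normalization: The Missing Ingredient for Fast Stylization" (https://arxiv.org/abs/1607.08022). Each channel of each sample is normalized with the mean and variance computed over its spatial dimensions (H and W). A learned per-channel shift (beta, if center=True) and scale (gamma, if scale=True) is then applied.

Args:
  x (tf.Tensor): a 4D tensor, laid out according to data_format.
  epsilon (float): small constant added to the variance to avoid division by zero.
  center, scale (bool): whether to apply the learned shift (beta) and scale (gamma).
  gamma_initializer: initializer for gamma (default: ones).
  data_format: 'channels_last' (NHWC, the default) or 'channels_first' (NCHW).
  use_affine: deprecated; use center= or scale= instead.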
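Written out for an NHWC input of shape N × H × W × C, the normalization described above amounts to the following. The indices n, h, w, c are notation introduced here; the variance is the biased (divide by HW) estimate returned by tf.nn.moments, and with center=False the shift is β_c = 0, with scale=False the scale is γ_c = 1.

```latex
\mu_{nc} = \frac{1}{HW} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{nhwc},
\qquad
\sigma^2_{nc} = \frac{1}{HW} \sum_{h=1}^{H} \sum_{w=1}^{W} \left( x_{nhwc} - \mu_{nc} \right)^2

y_{nhwc} = \gamma_c \, \frac{x_{nhwc} - \mu_{nc}}{\sqrt{\sigma^2_{nc} + \epsilon}} + \beta_c
```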

InstanceNorm(x, epsilon=1e-5, *, center=True, scale=True,
             gamma_initializer=tf.ones_initializer(),
             data_format='channels_last', use_affine=None)
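The data_format parameter in this signature only decides which axes are averaged over and how the per-channel parameters are reshaped so they broadcast against x. The minimal pure-Python sketch below mirrors that branch of the source. `instance_norm_axes` and `_TO_TF_LAYOUT` are hypothetical names, and the mapping from 'channels_last'/'channels_first' to 'NHWC'/'NCHW' is an assumption about what tensorpack's `get_data_format(..., keras_mode=False)` does.

```python
from typing import List, Optional, Tuple

# ASSUMPTION: mirrors what get_data_format(data_format, keras_mode=False) is
# expected to do in tensorpack (Keras-style names -> TF layout strings).
_TO_TF_LAYOUT = {"channels_last": "NHWC", "channels_first": "NCHW"}


def instance_norm_axes(data_format: str, shape: List[Optional[int]]
                       ) -> Tuple[List[int], int, List[int]]:
    """Hypothetical helper: return (reduction axes, channel count,
    broadcast shape for beta/gamma) the way InstanceNorm derives them."""
    fmt = _TO_TF_LAYOUT.get(data_format, data_format)
    if fmt not in ("NHWC", "NCHW"):
        raise ValueError("unknown data_format: %r" % (data_format,))
    if len(shape) != 4:
        raise ValueError("Input of InstanceNorm has to be 4D!")
    if fmt == "NHWC":
        axis, ch = [1, 2], shape[3]   # H, W are axes 1 and 2; channels last
    else:
        axis, ch = [2, 3], shape[1]   # H, W are axes 2 and 3; channels at axis 1
    # Only the channel count must be static; batch and spatial dims may be None.
    if ch is None:
        raise ValueError("InstanceNorm needs a known channel dimension")
    new_shape = [1, 1, 1, ch] if fmt == "NHWC" else [1, ch, 1, 1]
    return axis, ch, new_shape


print(instance_norm_axes("channels_last", [None, 32, 32, 64]))   # ([1, 2], 64, [1, 1, 1, 64])
print(instance_norm_axes("channels_first", [8, 3, 224, 224]))    # ([2, 3], 3, [1, 3, 1, 1])
```

Raising ValueError is a choice of this sketch; the original function uses assert statements for the same checks.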


# tf, get_data_format, log_deprecated and VariableHolder come from the
# surrounding tensorpack module (see "Calls" below).
# ... tail of the decorator applied to InstanceNorm (earlier lines not shown):
#         'gamma_init': 'gamma_initializer',
#     })
def InstanceNorm(x, epsilon=1e-5, *, center=True, scale=True,
                 gamma_initializer=tf.ones_initializer(),
                 data_format='channels_last', use_affine=None):
    """
    Instance Normalization, as in the paper:
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`_.

    Args:
        x (tf.Tensor): a 4D tensor.
        epsilon (float): avoid divide-by-zero
        center, scale (bool): whether to use the extra affine transformation or not.
        use_affine: deprecated. Don't use.
    """
    # Convert 'channels_last'/'channels_first' to 'NHWC'/'NCHW'.
    data_format = get_data_format(data_format, keras_mode=False)
    shape = x.get_shape().as_list()
    assert len(shape) == 4, "Input of InstanceNorm has to be 4D!"

    # Backward compatibility: use_affine sets both center and scale.
    if use_affine is not None:
        log_deprecated("InstanceNorm(use_affine=)", "Use center= or scale= instead!", "2020-06-01")
        center = scale = use_affine

    # Reduce over the spatial axes (H, W); new_shape lets the
    # per-channel beta/gamma broadcast against x.
    if data_format == 'NHWC':
        axis = [1, 2]
        ch = shape[3]
        new_shape = [1, 1, 1, ch]
    else:
        axis = [2, 3]
        ch = shape[1]
        new_shape = [1, ch, 1, 1]
    assert ch is not None, "Input of InstanceNorm require known channel!"

    # Mean and variance per sample and per channel.
    mean, var = tf.nn.moments(x, axis, keep_dims=True)

    if center:
        beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
        beta = tf.reshape(beta, new_shape)
    else:
        # No learned shift: a constant zero that broadcasts against x.
        beta = tf.zeros([1, 1, 1, 1], name='beta', dtype=x.dtype)
    if scale:
        gamma = tf.get_variable('gamma', [ch], initializer=gamma_initializer)
        gamma = tf.reshape(gamma, new_shape)
    else:
        # No learned scale: a constant one that broadcasts against x.
        gamma = tf.ones([1, 1, 1, 1], name='gamma', dtype=x.dtype)
    # (x - mean) / sqrt(var + epsilon) * gamma + beta
    ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')

    # Expose the (reshaped) parameters as ret.variables.gamma / ret.variables.beta.
    vh = ret.variables = VariableHolder()
    if scale:
        vh.gamma = gamma
    if center:
        vh.beta = beta
    return ret
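To see the arithmetic without TensorFlow, here is a pure-Python reference of what the function computes for an NHWC input. It is a sketch for checking intuition, not tensorpack code: `instance_norm_nhwc` is a hypothetical name, gamma and beta are passed in as plain lists instead of being created with tf.get_variable, and the variance is the biased (divide by H·W) estimate that tf.nn.moments returns.

```python
import math
from typing import List, Optional

Tensor4D = List[List[List[List[float]]]]  # nested lists indexed [n][h][w][c] (NHWC)


def instance_norm_nhwc(x: Tensor4D, epsilon: float = 1e-5,
                       gamma: Optional[List[float]] = None,
                       beta: Optional[List[float]] = None) -> Tensor4D:
    """Pure-Python reference of the math InstanceNorm performs on NHWC input.

    gamma/beta default to ones and zeros (InstanceNorm's initial values),
    so by default the affine step is the identity."""
    n_batch, height, width, ch = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    gamma = gamma if gamma is not None else [1.0] * ch
    beta = beta if beta is not None else [0.0] * ch
    out = [[[[0.0] * ch for _ in range(width)] for _ in range(height)]
           for _ in range(n_batch)]
    count = height * width
    for n in range(n_batch):
        for c in range(ch):
            # Statistics of one sample and one channel over H and W only,
            # matching tf.nn.moments(x, [1, 2], keep_dims=True).
            vals = [x[n][h][w][c] for h in range(height) for w in range(width)]
            mean = sum(vals) / count
            var = sum((v - mean) ** 2 for v in vals) / count  # biased variance
            inv = 1.0 / math.sqrt(var + epsilon)
            for h in range(height):
                for w in range(width):
                    out[n][h][w][c] = (x[n][h][w][c] - mean) * inv * gamma[c] + beta[c]
    return out


# One 2x2 image with 2 channels; channel 1 is exactly 10 * channel 0.
x = [[[[1.0, 10.0], [3.0, 30.0]],
      [[5.0, 50.0], [7.0, 70.0]]]]
y = instance_norm_nhwc(x)
print([round(y[0][h][w][0], 3) for h in range(2) for w in range(2)])  # [-1.342, -0.447, 0.447, 1.342]
print([round(y[0][h][w][1], 3) for h in range(2) for w in range(2)])  # [-1.342, -0.447, 0.447, 1.342]
```

Because the statistics are computed separately for every (sample, channel) pair, the two channels normalize to the same values even though one is ten times larger: per-channel contrast within each instance is removed before gamma and beta are applied.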

Callers (2)

INReLU (function)
INLReLU (function)

Calls (4)

get_data_format (function)
log_deprecated (function)
VariableHolder (class)
get_variable (method)

Tested by

no test coverage detected