Instance Normalization, as in the paper: `Instance Normalization: The Missing Ingredient for Fast Stylization <https://arxiv.org/abs/1607.08022>`_. Args: x (tf.Tensor): a 4D tensor. epsilon (float): avoid divide-by-zero. center, scale (bool): whether to use the extra affine transformation or not. use_affine: deprecated. Don't use.
(x, epsilon=1e-5, *, center=True, scale=True,
gamma_initializer=tf.ones_initializer(),
data_format='channels_last', use_affine=None)
| 78 | 'gamma_init': 'gamma_initializer', |
| 79 | }) |
def InstanceNorm(x, epsilon=1e-5, *, center=True, scale=True,
                 gamma_initializer=tf.ones_initializer(),
                 data_format='channels_last', use_affine=None):
    """
    Instance Normalization, following the paper
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`_.

    Mean and variance are computed per sample and per channel over the two
    spatial axes of ``x``. The normalized result is then optionally shifted
    by a learned per-channel ``beta`` and scaled by a learned per-channel
    ``gamma``.

    Args:
        x (tf.Tensor): a 4D tensor. Its channel dimension must be statically
            known.
        epsilon (float): avoid divide-by-zero.
        center (bool): if True, create a learned ``beta`` variable of shape
            ``[C]``; otherwise use a constant zero offset.
        scale (bool): if True, create a learned ``gamma`` variable of shape
            ``[C]`` initialized by ``gamma_initializer``; otherwise use a
            constant scale of one.
        gamma_initializer: initializer used for ``gamma`` when ``scale`` is True.
        data_format (str): layout of ``x``. When it resolves to ``'NHWC'``,
            channels are on axis 3; any other resolved value is treated as
            channels-first, with channels on axis 1.
        use_affine: deprecated. Don't use. When not None, it overrides both
            ``center`` and ``scale``.

    Returns:
        tf.Tensor: the normalized tensor (op name ``output``). Its ``variables``
        attribute is a ``VariableHolder`` exposing ``gamma`` (if ``scale``) and
        ``beta`` (if ``center``). Note that these are the reshaped,
        broadcastable tensors, not the raw ``[C]`` variables.
    """
    data_format = get_data_format(data_format, keras_mode=False)
    shape = x.get_shape().as_list()
    assert len(shape) == 4, "Input of InstanceNorm has to be 4D!"

    if use_affine is not None:
        log_deprecated("InstanceNorm(use_affine=)", "Use center= or scale= instead!", "2020-06-01")
        center, scale = use_affine, use_affine

    # Index of the channel axis; the other two non-batch axes are spatial.
    channel_axis = 3 if data_format == 'NHWC' else 1
    axis = [dim for dim in (1, 2, 3) if dim != channel_axis]
    ch = shape[channel_axis]
    # Shape that broadcasts a length-`ch` vector along the channel axis of x.
    new_shape = [1, 1, 1, 1]
    new_shape[channel_axis] = ch
    assert ch is not None, "Input of InstanceNorm require known channel!"

    # Statistics over the spatial axes only, kept broadcastable against x.
    mean, var = tf.nn.moments(x, axis, keep_dims=True)

    def _affine_param(enabled, name, initializer, constant_fn):
        # Learned [ch] variable reshaped to `new_shape`, or, when disabled,
        # a constant [1, 1, 1, 1] tensor built by `constant_fn`.
        if not enabled:
            return constant_fn([1, 1, 1, 1], name=name, dtype=x.dtype)
        param = tf.get_variable(name, [ch], initializer=initializer)
        return tf.reshape(param, new_shape)

    # beta is built before gamma so graph ops are created in the same order.
    beta = _affine_param(center, 'beta', tf.constant_initializer(), tf.zeros)
    gamma = _affine_param(scale, 'gamma', gamma_initializer, tf.ones)
    ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')

    holder = VariableHolder()
    ret.variables = holder
    # Register gamma, then beta.
    if scale:
        holder.gamma = gamma
    if center:
        holder.beta = beta
    return ret
no test coverage detected