```python
})
@disable_autograph()
def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
              center=True, scale=True,
              beta_initializer=tf.zeros_initializer(),
              gamma_initializer=tf.ones_initializer(),
              virtual_batch_size=None,
              data_format='channels_last',
              ema_update='default',
              sync_statistics=None):
    """
    A more powerful version of `tf.layers.batch_normalization`. It differs from
    the official one in the following aspects:

    1. Accepts an alternative ``data_format`` option when ``axis`` is None. For 2D input, this argument is ignored.
    2. Default values for ``momentum`` and ``epsilon`` are different.
    3. Default value for ``training`` is automatically obtained from tensorpack's ``TowerContext``.
       A user-provided value overrides this behavior.
    4. Supports the ``ema_update`` option, which covers broader use cases than the standard EMA update.
    5. Supports the ``sync_statistics`` option, which implements "SyncBN" and is very useful in small-batch models.
    6. Better support for the ``virtual_batch_size`` option, without the bugs in ``tf.layers``.

    Args:
        training (bool): if True, use per-batch statistics to normalize. Otherwise, use the stored EMA
            to normalize. By default, it is equal to `get_current_tower_context().is_training`.
            This is not a good argument name, but it is what the TensorFlow layer uses.
        virtual_batch_size (int): implement "Ghost BatchNorm", which normalizes
            the data with a smaller batch size than the input. Only effective when training is True.
            The value has to be a divisor of the actual batch size.

            It does not use the buggy TensorFlow implementation, which has the
            problems of (1) wrong behavior at inference; (2) creating variables with unnecessary size-1 dimensions.
            Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/23050
        ema_update (str): Only effective when ``training=True``. It has the following options:

          * "default": same as "collection", because this is the default behavior in TensorFlow.
          * "skip": do not update EMA. This can be useful when you reuse a batch norm layer in several places
            but do not want all of them to update your EMA.
          * "collection": Add EMA update ops to the collection `tf.GraphKeys.UPDATE_OPS` in the first training tower.
            The ops in the collection will be run automatically by the callback :class:`RunUpdateOps`, along with
            your training iterations. This can waste compute if your training iterations do not always depend
            on the BatchNorm layer.
          * "internal": EMA is updated in the first training tower inside this layer itself, via control dependencies.
            In standard scenarios, it has similar speed to "collection", but it supports more scenarios:

            1. BatchNorm is used inside dynamic control flow.
               The collection-based update does not support dynamic control flow.
            2. The BatchNorm layer is sometimes unused (e.g., in GANs you have two networks to train alternately).
               Putting all update ops into a single collection will waste a lot of compute.
            3. Other parts of the model rely on the "updated" EMA. The collection-based method does not update
               EMA immediately.
            4. It is less likely to trigger TensorFlow bugs in a graph with complicated control flow.

            Therefore, this option is preferred over the TensorFlow default.
            Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/14699
        sync_statistics (str or None): one of None, "nccl", or "horovod". It determines how to compute the
            "per-batch statistics" when ``training==True``.

            * None: it uses statistics of the input tensor to normalize during training.
              This is the standard way BatchNorm is implemented in most frameworks.
```
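Item 1 of the docstring states the rule for choosing the normalized axis when `axis` is None. The snippet below is a minimal sketch of that rule as a hypothetical pure-Python helper, `resolve_bn_axis` (not a tensorpack function). It assumes the two `data_format` values map to axis 1 (`'channels_first'`, as in NCHW) and to the last axis (`'channels_last'`, as in NHWC).

```python
def resolve_bn_axis(ndim, axis=None, data_format='channels_last'):
    """Pick the channel axis that BatchNorm keeps separate statistics for.

    Hypothetical helper: an explicit ``axis`` always wins; otherwise
    ``data_format`` decides, except for 2D ([N, C]) input where it is ignored.
    """
    if axis is not None:
        return axis
    if ndim == 2:
        # [N, C]: channels are always on axis 1, so data_format is ignored.
        return 1
    if data_format == 'channels_first':
        return 1  # e.g. NCHW: channels right after the batch axis
    if data_format == 'channels_last':
        return ndim - 1  # e.g. NHWC: channels on the last axis
    raise ValueError("data_format must be 'channels_first' or 'channels_last'")


print(resolve_bn_axis(4))                                # 3
print(resolve_bn_axis(4, data_format='channels_first'))  # 1
print(resolve_bn_axis(2, data_format='channels_first'))  # 1 (2D: data_format ignored)
```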
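The `training` argument chooses where the normalization statistics come from: the current batch, or the stored moving averages. A minimal pure-Python sketch on `[N, C]` data, assuming `gamma=1` and `beta=0` for brevity; `channel_moments` and `batch_norm_forward` are hypothetical names for illustration only, and the default `epsilon=1e-5` mirrors the signature.

```python
import math


def channel_moments(batch):
    """Per-channel mean and biased variance of a batch shaped [N, C]."""
    n = len(batch)
    num_channels = len(batch[0])
    means = [sum(row[c] for row in batch) / n for c in range(num_channels)]
    variances = [sum((row[c] - means[c]) ** 2 for row in batch) / n
                 for c in range(num_channels)]
    return means, variances


def batch_norm_forward(batch, moving_mean, moving_var, training, epsilon=1e-5):
    """Normalize [N, C] data; gamma=1 and beta=0 are assumed for brevity."""
    if training:
        # training=True: normalize with the statistics of this very batch.
        mean, var = channel_moments(batch)
    else:
        # training=False: normalize with the stored moving averages (EMA).
        mean, var = moving_mean, moving_var
    return [[(row[c] - mean[c]) / math.sqrt(var[c] + epsilon)
             for c in range(len(row))] for row in batch]


batch = [[1.0], [3.0]]
print(batch_norm_forward(batch, [0.0], [1.0], training=True))   # about [[-1.0], [1.0]]
print(batch_norm_forward(batch, [0.0], [1.0], training=False))  # about [[1.0], [3.0]]
```

The same input gives different outputs in the two modes, which is why a wrong `training` value at inference time silently changes predictions.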
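"Ghost BatchNorm" (`virtual_batch_size`) splits the batch into smaller virtual batches and normalizes each one with its own statistics. A minimal pure-Python sketch with the hypothetical helper `ghost_batch_norm`; it assumes the virtual batches are consecutive slices of the input, which may not match how the layer actually partitions samples.

```python
import math


def channel_moments(batch):
    """Per-channel mean and biased variance of a batch shaped [N, C]."""
    n = len(batch)
    num_channels = len(batch[0])
    means = [sum(row[c] for row in batch) / n for c in range(num_channels)]
    variances = [sum((row[c] - means[c]) ** 2 for row in batch) / n
                 for c in range(num_channels)]
    return means, variances


def ghost_batch_norm(batch, virtual_batch_size, epsilon=1e-5):
    """Normalize each virtual batch of an [N, C] batch with its own statistics.

    Assumption: virtual batches are consecutive slices of the input.
    """
    n = len(batch)
    if n % virtual_batch_size != 0:
        raise ValueError("virtual_batch_size must be a divisor of the batch size")
    out = []
    for start in range(0, n, virtual_batch_size):
        group = batch[start:start + virtual_batch_size]
        mean, var = channel_moments(group)  # statistics of this virtual batch only
        out.extend([[(row[c] - mean[c]) / math.sqrt(var[c] + epsilon)
                     for c in range(len(row))] for row in group])
    return out


batch = [[0.0], [2.0], [10.0], [14.0]]
print(ghost_batch_norm(batch, virtual_batch_size=2))  # about [[-1.0], [1.0], [-1.0], [1.0]]
```

With a virtual batch size of 2, both halves come out as roughly `[-1, 1]` even though their raw values differ; normalizing all four samples together would instead use a single mean of 6.5.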
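The `ema_update` modes differ mainly in *when* the moving averages change. The toy model below is plain Python with no TensorFlow, and `ema` and `ToyMovingMean` are hypothetical names; it imitates only the timing, not how tensorpack builds its graph. It assumes the usual EMA step `moving = momentum * moving + (1 - momentum) * batch_stat` with the layer's default `momentum=0.9`, and models "collection" as a list of deferred callables that something like the `RunUpdateOps` callback runs later.

```python
def ema(moving, batch_stat, momentum=0.9):
    """One EMA step: moving <- momentum * moving + (1 - momentum) * batch_stat."""
    return [momentum * m + (1.0 - momentum) * b for m, b in zip(moving, batch_stat)]


class ToyMovingMean:
    """Toy model of *when* the moving mean changes under each ema_update mode.

    This only imitates the timing; it is not how tensorpack builds its graph.
    """

    def __init__(self, num_channels, momentum=0.9):
        self.value = [0.0] * num_channels
        self.momentum = momentum

    def observe(self, batch_mean, ema_update, update_ops):
        def update():
            self.value = ema(self.value, batch_mean, self.momentum)

        if ema_update == 'skip':
            return                     # the EMA never changes
        if ema_update == 'internal':
            update()                   # applied before the layer output is used
        elif ema_update in ('collection', 'default'):
            update_ops.append(update)  # deferred, like ops in tf.GraphKeys.UPDATE_OPS
        else:
            raise ValueError("unknown ema_update: %r" % (ema_update,))


update_ops = []
deferred = ToyMovingMean(1)
deferred.observe([1.0], 'collection', update_ops)
print(deferred.value)      # [0.0]: not updated yet
for op in update_ops:      # what a RunUpdateOps-style callback would do later
    op()
print(deferred.value)      # about [0.1]

immediate = ToyMovingMean(1)
immediate.observe([1.0], 'internal', update_ops=[])
print(immediate.value)     # about [0.1], already updated
```

In this model, code that reads `deferred.value` right after `observe` still sees the old EMA, which corresponds to scenario 3 in the "internal" list.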
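With `sync_statistics`, the per-batch statistics are meant to cover the combined batch across devices rather than each device's slice. A minimal pure-Python sketch of one common way to do this: each device computes per-channel `E[x]` and `E[x^2]`, these are averaged across devices, and the variance is recovered as `E[x^2] - E[x]^2`. It assumes every device has the same local batch size and replaces the real NCCL or Horovod all-reduce with a plain average; `local_moments` and `synced_moments` are hypothetical names.

```python
def local_moments(batch):
    """Per-channel E[x] and E[x^2] of one device's [N, C] batch."""
    n = len(batch)
    num_channels = len(batch[0])
    mean = [sum(row[c] for row in batch) / n for c in range(num_channels)]
    mean_sq = [sum(row[c] ** 2 for row in batch) / n for c in range(num_channels)]
    return mean, mean_sq


def synced_moments(per_device_batches):
    """Global mean and biased variance from per-device moments.

    The plain average stands in for an all-reduce and is only exact when
    every device has the same local batch size.
    """
    moments = [local_moments(b) for b in per_device_batches]
    k = len(moments)
    num_channels = len(moments[0][0])
    mean = [sum(m[0][c] for m in moments) / k for c in range(num_channels)]
    mean_sq = [sum(m[1][c] for m in moments) / k for c in range(num_channels)]
    # Var[x] = E[x^2] - (E[x])^2
    var = [mean_sq[c] - mean[c] ** 2 for c in range(num_channels)]
    return mean, var


devices = [[[0.0], [2.0]], [[10.0], [14.0]]]
print(synced_moments(devices))  # ([6.5], [32.75])
```

The result equals the mean (6.5) and biased variance (32.75) of all four samples taken together, which a device holding only `[0, 2]` could not compute on its own.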