```python
import tensorflow as tf  # TensorFlow 1.x: tf.contrib was removed in TF 2.x


def bn(inputs,
       is_training=True,
       activation_fn=None,
       scope="bn",
       reuse=None):
    '''Applies batch normalization.

    Args:
      inputs: A tensor with 2 or more dimensions, where the first dimension is
        `batch_size`. If type is `bn`, the normalization is over all but
        the last dimension. If type is `ln`, the normalization is over
        the last dimension only. Note that the latter differs from the native
        `tf.contrib.layers.batch_norm`. For that case, I recommend changing
        one line in `tensorflow/contrib/layers/python/layers/layers.py`
        as follows.
          Before: mean, variance = nn.moments(inputs, axis, keep_dims=True)
          After: mean, variance = nn.moments(inputs, [-1], keep_dims=True)
      is_training: Whether or not the layer is in training mode.
      activation_fn: Activation function applied to the normalized output,
        or None for no activation.
      scope: Optional scope for `variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.

    Returns:
      A tensor with the same shape and dtype as `inputs`.
    '''
    inputs_shape = inputs.get_shape()
    inputs_rank = inputs_shape.ndims

    # Use fused batch norm if inputs_rank is 2, 3 or 4, as it is much faster.
    # Fused batch norm requires a rank-4 (NHWC) input, so lower-rank inputs
    # are expanded here and squeezed back after normalization.
    if inputs_rank in [2, 3, 4]:
        if inputs_rank == 2:
            inputs = tf.expand_dims(inputs, axis=1)  # [N, C] -> [N, 1, C]
            inputs = tf.expand_dims(inputs, axis=2)  # -> [N, 1, 1, C]
        elif inputs_rank == 3:
            inputs = tf.expand_dims(inputs, axis=1)  # [N, T, C] -> [N, 1, T, C]

        outputs = tf.contrib.layers.batch_norm(inputs=inputs,
                                               center=True,
                                               scale=True,
                                               updates_collections=None,
                                               is_training=is_training,
                                               scope=scope,
                                               fused=True,
                                               reuse=reuse)
        # Restore the original shape.
        if inputs_rank == 2:
            outputs = tf.squeeze(outputs, axis=[1, 2])
        elif inputs_rank == 3:
            outputs = tf.squeeze(outputs, axis=1)
    else:  # Fall back to naive (non-fused) batch norm.
        outputs = tf.contrib.layers.batch_norm(inputs=inputs,
                                               center=True,
                                               scale=True,
                                               updates_collections=None,
                                               is_training=is_training,
                                               scope=scope,
                                               reuse=reuse,
                                               fused=False)

    if activation_fn is not None:
        outputs = activation_fn(outputs)

    return outputs
```
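The docstring contrasts normalizing over all but the last dimension (`bn`) with normalizing over the last dimension only (`ln`). A minimal NumPy sketch of the two choices, assuming a `[batch, time, channels]` input, follows. It is an illustration rather than TensorFlow's implementation: the helper `standardize` is hypothetical, the epsilon of `1e-3` is an assumed value, and the learned offset and scale added by `center=True, scale=True` are omitted.

```python
import numpy as np


def standardize(x, axes, eps=1e-3):
    """Subtract the mean and divide by the std, both computed over `axes`.

    Hypothetical helper: it omits the learned scale (gamma) and offset (beta).
    """
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


x = np.random.default_rng(0).normal(size=(8, 5, 3))  # [batch, time, channels]

# `bn`: statistics over all but the last dimension -> one mean/var per channel.
bn_axes = tuple(range(x.ndim - 1))  # (0, 1)
y_bn = standardize(x, bn_axes)

# `ln`: statistics over the last dimension only -> one mean/var per position,
# i.e. the axes the suggested nn.moments(inputs, [-1], ...) patch would use.
y_ln = standardize(x, (-1,))

print(y_bn.mean(axis=bn_axes).round(6))  # approximately 0 for each channel
print(y_ln.mean(axis=-1).round(6)[0])    # approximately 0 for each position
```

With the `bn` axes each channel ends up with roughly zero mean across batch and time, whereas with the `ln` axes each individual `[batch, time]` position does.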
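The rank-2 and rank-3 branches work because inserting size-1 dimensions does not change which elements share a channel: per-channel statistics over N, H and W of the expanded NHWC tensor equal statistics over all but the last dimension of the original input. The following NumPy sketch checks that equivalence. `to_nhwc`, `from_nhwc` and `per_channel_norm` are hypothetical names, and `per_channel_norm` stands in for training-mode batch normalization without the learned offset and scale.

```python
import numpy as np


def to_nhwc(x):
    """Mirror bn(): lift rank-2/3 inputs to rank 4 before the fused op."""
    if x.ndim == 2:  # [N, C] -> [N, 1, 1, C], like the two tf.expand_dims calls
        return x[:, np.newaxis, np.newaxis, :]
    if x.ndim == 3:  # [N, T, C] -> [N, 1, T, C], like tf.expand_dims(axis=1)
        return x[:, np.newaxis, :, :]
    return x  # rank 4 is assumed to already be NHWC


def from_nhwc(y, rank):
    """Mirror the tf.squeeze calls that restore the original shape."""
    if rank == 2:
        return y.squeeze(axis=(1, 2))
    if rank == 3:
        return y.squeeze(axis=1)
    return y


def per_channel_norm(x, eps=1e-3):
    """Training-mode batch norm without gamma/beta: stats over all but last dim."""
    axes = tuple(range(x.ndim - 1))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(1)
for shape in [(6, 4), (6, 7, 4)]:
    x = rng.normal(size=shape)
    via_nhwc = from_nhwc(per_channel_norm(to_nhwc(x)), x.ndim)
    direct = per_channel_norm(x)
    assert via_nhwc.shape == x.shape
    assert np.allclose(via_nhwc, direct)  # same statistics either way
```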
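`is_training` selects which statistics are used, and `updates_collections=None` asks `tf.contrib.layers.batch_norm` to update its moving mean and variance in place during training instead of leaving the update ops in a collection for the caller to run. The sketch below models that behaviour in NumPy under stated assumptions: `MovingStatsBN` is a hypothetical class, its defaults `decay=0.999` and `eps=1e-3` are assumed to mirror the contrib defaults, the moving statistics are assumed to start at mean 0 and variance 1, and details such as the variance correction used by the fused kernel are ignored.

```python
import numpy as np


class MovingStatsBN:
    """Hypothetical per-channel batch norm with moving averages (no gamma/beta).

    The moving statistics are updated in place on every training call, which
    models the effect of updates_collections=None in the TF layer.
    """

    def __init__(self, channels, decay=0.999, eps=1e-3):
        self.decay = decay
        self.eps = eps
        self.moving_mean = np.zeros(channels)  # assumed initial value
        self.moving_var = np.ones(channels)    # assumed initial value

    def __call__(self, x, is_training=True):
        axes = tuple(range(x.ndim - 1))  # all but the last (channel) dimension
        if is_training:
            # Normalize with the current batch's statistics...
            mean = x.mean(axis=axes)
            var = x.var(axis=axes)
            # ...and fold them into the moving averages right away.
            self.moving_mean = self.decay * self.moving_mean + (1 - self.decay) * mean
            self.moving_var = self.decay * self.moving_var + (1 - self.decay) * var
        else:
            # At inference time, use only the accumulated moving statistics.
            mean, var = self.moving_mean, self.moving_var
        return (x - mean) / np.sqrt(var + self.eps)


# Usage: train on batches drawn from N(5, 2^2), then run in inference mode.
layer = MovingStatsBN(channels=3, decay=0.9)  # small decay so stats settle fast
rng = np.random.default_rng(2)
for _ in range(200):
    layer(rng.normal(loc=5.0, scale=2.0, size=(32, 3)), is_training=True)

print(layer.moving_mean)  # close to 5 for each channel
print(layer.moving_var)   # close to 4 for each channel
y = layer(np.full((4, 3), 5.0), is_training=False)  # close to 0 everywhere
```

In training mode the output depends only on the current batch, while in inference mode it depends only on the accumulated moving averages, so the same input can produce different outputs depending on `is_training`.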