(inputs, red_axis, sync_statistics)
| 61 | |
| 62 | |
def get_sync_bn_mean_var(inputs, red_axis, sync_statistics):
    """
    Compute the batch mean and variance of `inputs`, optionally synchronized
    across devices or processes.

    The local mean of `inputs` and of `inputs ** 2` are taken over `red_axis`.
    With 'nccl' or 'horovod', both moments are averaged across participants
    before the variance is derived as E[x^2] - E[x]^2.

    Args:
        inputs: the tensor whose statistics are computed.
        red_axis: axes passed to `tf.reduce_mean` to reduce over.
        sync_statistics (str or None):
            'nccl' - all-reduce across the towers of the current tower context.
            'horovod' - all-reduce across horovod processes.
            None - use only the local statistics (no synchronization).

    Returns:
        tuple: (batch_mean, batch_var).

    Raises:
        ValueError: if `sync_statistics` is not one of None, 'nccl', 'horovod'.
    """
    # Validate first: an unrecognized value (e.g. a typo like 'NCCL') would
    # otherwise fall through both branches and silently return unsynchronized
    # statistics.
    if sync_statistics not in (None, 'nccl', 'horovod'):
        raise ValueError(
            "sync_statistics must be None, 'nccl' or 'horovod', got {!r}".format(sync_statistics))

    ctx = get_current_tower_context()
    # Local first and second moments; combined into the variance at the end.
    batch_mean = tf.reduce_mean(inputs, axis=red_axis)
    batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)

    TF_version = get_tf_version_tuple()

    if sync_statistics == 'nccl':
        num_dev = ctx.total
        if num_dev == 1:
            logger.warn("BatchNorm(sync_statistics='nccl') is used with only one tower!")
        else:
            assert TF_version >= (1, 10), \
                "Cross-GPU BatchNorm is only supported in TF>=1.10. " \
                "Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"
            if TF_version >= (1, 15):
                logger.warn("BatchNorm(sync_statistics='nccl') may produce incorrect results due "
                            "to bug in TF>=1.15: https://github.com/tensorflow/tensorflow/issues/41539")

            if TF_version <= (1, 12):
                try:
                    from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so  # deprecated
                except Exception:
                    # Deliberate best effort: this loader is not present in every
                    # TF version; proceed without it when the import fails.
                    pass
                else:
                    _validate_and_load_nccl_so()
                from tensorflow.contrib.nccl.ops import gen_nccl_ops  # deprecated
            else:
                from tensorflow.python.ops import gen_nccl_ops
            # Drop the per-tower 'towerN/' prefix so the ops built in every tower
            # for this layer get the same shared_name; ops taking part in the
            # same NCCL reduction are identified by that name.
            shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
            # Sum then scale by 1/num_dev: an equal-weight average, which
            # presumably assumes every tower sees the same batch size.
            batch_mean = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
            batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean_square,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
    elif sync_statistics == 'horovod':
        # Require https://github.com/uber/horovod/pull/331
        import horovod.tensorflow as hvd
        if hvd.size() == 1:
            logger.warn("BatchNorm(sync_statistics='horovod') is used with only one process!")
        else:
            import horovod
            # Take only the numeric components so that pre-release / local
            # version strings such as '0.19.0rc1' don't make int() raise.
            hvd_version = tuple(int(part) for part in re.findall(r'\d+', horovod.__version__)[:3])
            assert hvd_version >= (0, 13, 6), "sync_statistics=horovod needs horovod>=0.13.6 !"

            batch_mean = hvd.allreduce(batch_mean, average=True)
            batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
    # Var[x] = E[x^2] - E[x]^2, from the (possibly synchronized) moments.
    batch_var = batch_mean_square - tf.square(batch_mean)
    return batch_mean, batch_var
| 117 | |
| 118 | |
| 119 | @layer_register() |
no test coverage detected