MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / get_sync_bn_mean_var

Function get_sync_bn_mean_var

tensorpack/models/batch_norm.py:63–116  ·  view source on GitHub ↗
(inputs, red_axis, sync_statistics)

Source from the content-addressed store, hash-verified

61
62
def get_sync_bn_mean_var(inputs, red_axis, sync_statistics):
    """Return the batch mean and variance of ``inputs``, optionally averaged across workers.

    The first moment ``E[x]`` and second moment ``E[x^2]`` are reduced locally
    over ``red_axis``. If ``sync_statistics`` is ``'nccl'`` or ``'horovod'`` and
    more than one tower/process participates, both moments are all-reduced and
    averaged across participants. The variance is then ``E[x^2] - E[x]^2``.
    With any other ``sync_statistics`` value the local moments are used as-is.

    Args:
        inputs: tensor whose statistics are computed.
        red_axis: axis argument forwarded to ``tf.reduce_mean``.
        sync_statistics: ``'nccl'`` or ``'horovod'`` to synchronize; any other
            value skips synchronization.

    Returns:
        tuple: ``(batch_mean, batch_var)``.

    Raises:
        AssertionError: if NCCL sync is requested with more than one tower on
            TF older than 1.10, or Horovod sync is requested with more than one
            process on Horovod older than 0.13.6.
    """
    tower_ctx = get_current_tower_context()
    mean = tf.reduce_mean(inputs, axis=red_axis)
    mean_sq = tf.reduce_mean(tf.square(inputs), axis=red_axis)

    tf_ver = get_tf_version_tuple()

    if sync_statistics == 'nccl':
        n_towers = tower_ctx.total
        if n_towers != 1:
            assert tf_ver >= (1, 10), (
                "Cross-GPU BatchNorm is only supported in TF>=1.10 ."
                "Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"
            )
            if tf_ver >= (1, 15):
                logger.warn("BatchNorm(sync_statistics='nccl') may produce incorrect results due "
                            "to bug in TF>=1.15: https://github.com/tensorflow/tensorflow/issues/41539")

            # The NCCL op bindings live in different modules depending on the TF version.
            if tf_ver > (1, 12):
                from tensorflow.python.ops import gen_nccl_ops as nccl_ops
            else:
                # Best effort: the loader helper is optional, so a failed import is ignored,
                # but errors raised by the loader itself still propagate.
                load_nccl_so = None
                try:
                    from tensorflow.contrib.nccl.python.ops.nccl_ops import (  # deprecated
                        _validate_and_load_nccl_so as load_nccl_so,
                    )
                except Exception:
                    pass
                if load_nccl_so is not None:
                    load_nccl_so()
                from tensorflow.contrib.nccl.ops import gen_nccl_ops as nccl_ops  # deprecated

            # Remove the per-tower prefix so that every tower uses the same shared_name.
            scope_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)

            def _cross_tower_average(tensor, suffix):
                # Sum over all towers, then scale by 1/n to obtain the average.
                summed = nccl_ops.nccl_all_reduce(
                    input=tensor,
                    reduction='sum',
                    num_devices=n_towers,
                    shared_name=scope_name + suffix)
                return summed * (1.0 / n_towers)

            mean = _cross_tower_average(mean, '_NCCL_mean')
            mean_sq = _cross_tower_average(mean_sq, '_NCCL_mean_square')
        else:
            logger.warn("BatchNorm(sync_statistics='nccl') is used with only one tower!")
    elif sync_statistics == 'horovod':
        # Require https://github.com/uber/horovod/pull/331
        import horovod.tensorflow as hvd
        if hvd.size() != 1:
            import horovod
            version_parts = horovod.__version__.split('.')[:3]
            hvd_ver = tuple(int(part) for part in version_parts)
            assert hvd_ver >= (0, 13, 6), "sync_statistics=horovod needs horovod>=0.13.6 !"

            mean = hvd.allreduce(mean, average=True)
            mean_sq = hvd.allreduce(mean_sq, average=True)
        else:
            logger.warn("BatchNorm(sync_statistics='horovod') is used with only one process!")

    # Var[x] = E[x^2] - E[x]^2
    return mean, mean_sq - tf.square(mean)
117
118
119@layer_register()

Callers 1

BatchNorm (Function) · 0.85

Calls 4

get_tf_version_tuple (Function) · 0.85
allreduce (Method) · 0.80
size (Method) · 0.45

Tested by

no test coverage detected