github.com/tensorpack/tensorpack

Function allreduce_grads

tensorpack/graph_builder/utils.py:148–198

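All-reduce the gradients among K devices and optionally average them; the results are broadcast to all devices. Args: all_grads (K x N): list of lists of gradients, where N is the number of variables. average (bool): whether to average the gradients (otherwise they are only summed). mode (str): "nccl" or "collective". Returns: K x N, same as the input, but each gradient is replaced by its sum (or average) over the K devices.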

(all_grads, average, mode="nccl")
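The K x N to N x K transposition, per-variable reduction and transposition back can be checked against a small pure-Python reference. This is a minimal sketch, assuming plain floats stand in for gradient tensors and ignoring device placement and NCCL entirely; `reference_allreduce` is a hypothetical name, not part of tensorpack.

```python
from typing import List, Sequence, Tuple


def reference_allreduce(all_grads: Sequence[Sequence[float]],
                        average: bool) -> List[Tuple[float, ...]]:
    """Hypothetical reference: K x N floats in, K x N floats out."""
    nr_tower = len(all_grads)
    if nr_tower == 1:
        # Nothing to reduce with; returned as tuples here for a uniform type.
        return [tuple(g) for g in all_grads]
    new_all_grads = []  # N x K
    for grads in zip(*all_grads):  # the K copies of one variable's gradient
        total = sum(grads)
        if average:
            total = total * (1.0 / nr_tower)
        # Every device ends up holding the same reduced value.
        new_all_grads.append([total] * nr_tower)
    # Transpose back from N x K to K x N.
    return list(zip(*new_all_grads))


# Two devices (K=2), three variables (N=3).
grads = [[1.0, 2.0, 3.0],
         [3.0, 4.0, 5.0]]
print(reference_allreduce(grads, average=False))  # [(4.0, 6.0, 8.0), (4.0, 6.0, 8.0)]
print(reference_allreduce(grads, average=True))   # [(2.0, 3.0, 4.0), (2.0, 3.0, 4.0)]
```

As in the real function, the reduction is done per variable (per column of the K x N input), and averaging is a multiplication by 1 / K applied after the sum.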


```python
# Excerpt from tensorpack/graph_builder/utils.py. Relies on module-level names
# defined or imported elsewhere in that file: tf, under_name_scope,
# get_tf_version_tuple and _get_shared_cnt.
@under_name_scope('AllReduceGrads')
def allreduce_grads(all_grads, average, mode="nccl"):
    """
    All-reduce (and optionally average) the gradients among K devices. Results are broadcast to all devices.

    Args:
        all_grads (K x N): List of list of gradients. N is the number of variables.
        average (bool): average gradients or not.
        mode (str): "nccl", "collective"

    Returns:
        K x N: same as input, but each grad is replaced by the sum over K devices
        (or the average, if average is True).
    """
    assert mode in ["nccl", "collective"], mode

    nr_tower = len(all_grads)
    if nr_tower == 1:
        return all_grads
    new_all_grads = []  # N x K
    for grads in zip(*all_grads):
        # grads: the K copies of one variable's gradient, one per device
        if mode == "nccl":
            if get_tf_version_tuple() <= (1, 12):
                from tensorflow.contrib import nccl  # deprecated
            else:
                from tensorflow.python.ops import nccl_ops as nccl
            summed = nccl.all_sum(grads)
        else:
            from tensorflow.python.ops import collective_ops
            summed = []
            # one instance key per variable, shared by all K devices
            shared_cnt = _get_shared_cnt()
            for t in grads:
                with tf.device(t.device):
                    t = collective_ops.all_reduce(
                        t, len(grads),
                        42,  # group key is any fixed integer for a fixed group of devices
                        shared_cnt + 100,
                        'Add', 'Id', communication_hint='nccl')
                summed.append(t)

        grads_for_devices = []  # K
        for g in summed:
            with tf.device(g.device):
                # tensorflow/benchmarks didn't average gradients
                if average:
                    g = tf.multiply(g, 1.0 / nr_tower)
                grads_for_devices.append(g)
        new_all_grads.append(grads_for_devices)

    # transpose to K x N
    ret = list(zip(*new_all_grads))
    return ret
```
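In "collective" mode, every device taking part in one reduction has to pass the same group key and instance key to `collective_ops.all_reduce`, and separate reductions should use distinct instance keys. The loop above does this by calling `_get_shared_cnt()` once per variable and reusing `shared_cnt + 100` for all K devices. The following sketch models only that key bookkeeping in plain Python, assuming `_get_shared_cnt()` returns a fresh integer on every call; `next_shared_cnt` and `plan_collective_keys` are hypothetical names, and no TensorFlow op is run.

```python
import itertools
from typing import List, Tuple

GROUP_KEY = 42  # fixed for a fixed set of devices, as in the source
_counter = itertools.count(1)  # hypothetical stand-in for _get_shared_cnt()


def next_shared_cnt() -> int:
    """Hypothetical: return a new integer on every call."""
    return next(_counter)


def plan_collective_keys(num_vars: int, num_devices: int) -> List[List[Tuple[int, int]]]:
    """Return, per variable, the (group_key, instance_key) each device would pass."""
    plan = []
    for _ in range(num_vars):
        # One counter call per variable; every device reuses the same key.
        instance_key = next_shared_cnt() + 100
        plan.append([(GROUP_KEY, instance_key)] * num_devices)
    return plan


plan = plan_collective_keys(num_vars=3, num_devices=2)
for row in plan:
    print(row)
# [(42, 101), (42, 101)]
# [(42, 102), (42, 102)]
# [(42, 103), (42, 103)]
```

Drawing the key once per variable, outside the per-device loop, is what keeps the K participants of a single reduction in agreement while still giving each variable its own collective instance.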

Callers (1)

do_allreduce · method · 0.85

Calls (4)

get_tf_version_tuple · function · 0.85
_get_shared_cnt · function · 0.85
device · method · 0.80
append · method · 0.80

Tested by

No test coverage detected.