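All-reduce (sum, and optionally average) the gradients among K devices. Results are broadcast to all devices. Args: all_grads (K x N): List of list of gradients. N is the number of variables. average (bool): average gradients or not. mode (str): "nccl", "collective" Returns: K x N: same as input, but each grad is replaced by the sum (or, if `average` is set, the average) over K devices.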
(all_grads, average, mode="nccl")
| 146 | |
@under_name_scope('AllReduceGrads')
def allreduce_grads(all_grads, average, mode="nccl"):
    """
    Sum every variable's gradient across K devices and hand each device its
    own copy of the result, optionally divided by K.

    Args:
        all_grads (K x N): List of list of gradients, one inner list per
            device. N is the number of variables.
        average (bool): if true, each summed gradient is multiplied by
            ``1.0 / K``; otherwise the plain sum is returned.
        mode (str): "nccl" reduces with ``nccl.all_sum``; "collective" reduces
            with ``collective_ops.all_reduce``.

    Returns:
        K x N: for K == 1 the input object itself is returned untouched.
        Otherwise a list of K tuples, each holding N reduced gradients
        (the sum, or the average when ``average`` is set).
    """
    assert mode in ("nccl", "collective"), mode

    num_devices = len(all_grads)
    if num_devices == 1:
        # A single device has nothing to reduce with.
        return all_grads

    def _sum_across_devices(per_device_grads):
        # Takes the K per-device gradients of ONE variable and returns the
        # K summed tensors produced by the selected backend.
        if mode != "nccl":
            from tensorflow.python.ops import collective_ops
            reduced = []
            shared_cnt = _get_shared_cnt()
            for grad in per_device_grads:
                with tf.device(grad.device):
                    reduced.append(collective_ops.all_reduce(
                        grad, len(per_device_grads),
                        42,  # group key is any fixed integer for a fixed group of devices
                        shared_cnt + 100,
                        'Add', 'Id', communication_hint='nccl'))
            return reduced

        # The nccl ops live in a different module depending on the TF version.
        if get_tf_version_tuple() > (1, 12):
            from tensorflow.python.ops import nccl_ops as nccl
        else:
            from tensorflow.contrib import nccl  # deprecated
        return nccl.all_sum(per_device_grads)

    def _scale_on_own_device(summed_grad):
        # Any scaling op is created under the device of the summed tensor.
        with tf.device(summed_grad.device):
            # tensorflow/benchmarks didn't average gradients
            if average:
                # The division is evaluated only here, so an empty `all_grads`
                # (K == 0) never reaches it.
                summed_grad = tf.multiply(summed_grad, 1.0 / num_devices)
        return summed_grad

    # zip(*all_grads) walks the variables; each step yields that variable's
    # K per-device gradients. The result is laid out as N x K.
    per_variable = [
        [_scale_on_own_device(summed) for summed in _sum_across_devices(grads)]
        for grads in zip(*all_grads)
    ]

    # transpose back to K x N
    return list(zip(*per_variable))
| 199 | |
| 200 | |
| 201 | @under_name_scope('AllReduceGradsHierachical') |
no test coverage detected