Apply an all-reduce operation to the tensors in a Python dict. NOTE: make sure that every process passes a py_dict with the same keys and that the values for each key have the same shape. Args: py_dict (dict): dict whose values the all-reduce op is applied to. op (str): operator, could be "sum" or "mean".
(py_dict, op="sum", group=None)
| 57 | |
| 58 | |
def all_reduce(py_dict, op="sum", group=None):
    """
    All-reduce every tensor stored in a Python dict across processes.

    The key list is broadcast from rank 0 so that every process packs its
    tensors in the same order. The values are then flattened into a single
    tensor, reduced with one collective call, and split back into their
    original shapes.

    NOTE: make sure that every process passes a dict with the same keys and
    that the values for each key have the same shape.

    Args:
        py_dict (dict): dict whose values are the tensors to reduce.
        op (str): operator, could be "sum" or "mean".
        group: process group whose size decides whether there is anything
            to reduce. Defaults to the result of ``_get_global_gloo_group()``.

    Returns:
        ``py_dict`` itself, unchanged, when only one process participates.
        Otherwise a new ``OrderedDict`` mapping each broadcast key to its
        reduced tensor (empty when there are no keys).
    """
    world_size = get_world_size()
    if world_size == 1:
        return py_dict
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return py_dict

    # NOTE(review): `group` is only consulted for the size check above; the
    # broadcast and all_reduce below are issued without a `group` argument.
    # Confirm that this is intended before forwarding it.

    # Use rank 0's key order everywhere so the flattened layouts line up.
    keys = tensor2pyobj_keys = None  # placeholders replaced just below
    key_tensor = pyobj2tensor(list(py_dict.keys()))
    dist.broadcast(key_tensor, src=0)
    keys = tensor2pyobj(key_tensor)
    del tensor2pyobj_keys

    # Nothing to reduce: torch.cat() rejects an empty list, so stop here.
    # Every process sees the same broadcast key list and therefore takes this
    # exit together, leaving no collective call unmatched.
    if not keys:
        return OrderedDict()

    values = [py_dict[k] for k in keys]
    shapes = [v.shape for v in values]
    numels = [v.numel() for v in values]

    # One collective over a single flat buffer instead of one per tensor.
    flat = torch.cat([v.flatten() for v in values])
    dist.all_reduce(flat, op=_get_reduce_op(op))
    if op == "mean":
        # Divides by the size returned by get_world_size(), not the group size.
        flat /= world_size

    # Cut the flat buffer back into per-key tensors with their original shapes.
    reduced = [
        chunk.reshape(shape)
        for chunk, shape in zip(torch.split(flat, numels), shapes)
    ]
    return OrderedDict(zip(keys, reduced))
| 95 | |
| 96 | |
| 97 | def all_reduce_norm(module): |
no test coverage detected