Repository: github.com/FoundationVision/ByteTrack

Function all_reduce

yolox/utils/allreduce_norm.py, lines 59–94

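Applies an all-reduce operation to a Python dict of tensors. NOTE: make sure that every rank's py_dict has the same keys and that corresponding values have the same shape.

Args: py_dict (dict): dict to apply the all-reduce op to. op (str): operator, either "sum" or "mean". group (optional): process group to reduce over; when None, a global Gloo group is used.

Returns an OrderedDict of reduced tensors, or py_dict unchanged when the world (or the group) contains a single process.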

Signature: all_reduce(py_dict, op="sum", group=None)
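To make the contract described above concrete, here is a minimal single-process sketch of the values every rank should hold after the call. It assumes each rank's dict maps the same keys to flat lists of floats of equal length, standing in for tensors. `reference_all_reduce` is a hypothetical helper, not part of YOLOX or `torch.distributed`; it models only the arithmetic, not the communication, and like the real function it uses rank 0's key order.

```python
from collections import OrderedDict


def reference_all_reduce(rank_dicts, op="sum"):
    """Model the values every rank should hold after all_reduce.

    Hypothetical helper: each value is a flat list of floats standing in for a
    tensor, and no communication takes place.
    """
    if op not in ("sum", "mean"):
        raise ValueError(f"unsupported op: {op!r}")
    keys = list(rank_dicts[0].keys())  # rank 0's key order is used everywhere
    # Enforce the docstring's NOTE: same keys and same sizes on every rank.
    for d in rank_dicts:
        if set(d) != set(keys):
            raise ValueError("every rank must provide the same keys")
        for k in keys:
            if len(d[k]) != len(rank_dicts[0][k]):
                raise ValueError(f"size mismatch for key {k!r}")
    result = OrderedDict()
    for k in keys:
        # Element-wise sum across ranks for this key.
        summed = [sum(vals) for vals in zip(*(d[k] for d in rank_dicts))]
        if op == "mean":
            summed = [v / len(rank_dicts) for v in summed]
        result[k] = summed
    return result


# Illustrative per-rank dicts for two ranks (key names are made up).
ranks = [
    {"running_mean": [1.0, 2.0], "running_var": [4.0]},
    {"running_mean": [3.0, 6.0], "running_var": [8.0]},
]
print(dict(reference_all_reduce(ranks, op="mean")))
# {'running_mean': [2.0, 4.0], 'running_var': [6.0]}
```

The explicit checks mirror the NOTE: the real function does not validate its inputs, so mismatched keys or shapes there surface as communication errors or silently wrong results rather than a clear `ValueError`.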


```python
def all_reduce(py_dict, op="sum", group=None):
    """
    Apply all reduce function for python dict object.
    NOTE: make sure that every py_dict has the same keys and values are in the same shape.

    Args:
        py_dict (dict): dict to apply all reduce op.
        op (str): operator, could be "sum" or "mean".
    """
    world_size = get_world_size()
    if world_size == 1:
        return py_dict
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return py_dict

    # all reduce logic across different devices.
    py_key = list(py_dict.keys())
    py_key_tensor = pyobj2tensor(py_key)
    dist.broadcast(py_key_tensor, src=0)
    py_key = tensor2pyobj(py_key_tensor)

    tensor_shapes = [py_dict[k].shape for k in py_key]
    tensor_numels = [py_dict[k].numel() for k in py_key]

    flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])
    dist.all_reduce(flatten_tensor, op=_get_reduce_op(op))
    if op == "mean":
        flatten_tensor /= world_size

    split_tensors = [
        x.reshape(shape)
        for x, shape in zip(torch.split(flatten_tensor, tensor_numels), tensor_shapes)
    ]
    return OrderedDict({k: v for k, v in zip(py_key, split_tensors)})
```
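The function's central trick is packing: instead of one collective call per key, it concatenates every value into a single 1-D buffer in the agreed key order, reduces that buffer once, then splits it by element count and restores each shape. The stdlib-only sketch below mimics that flow with nested Python lists in place of tensors. `pack`, `unpack`, and `simulated_all_reduce` are hypothetical names, and an element-wise sum over per-rank buffers stands in for `dist.all_reduce`. As in the original, "mean" is a sum followed by a division; here the divisor is the number of simulated ranks, i.e. the size of the group being reduced over.

```python
from collections import OrderedDict
from math import prod


def flatten(nested):
    """Flatten a nested list of numbers (or a bare number) into a flat list."""
    if isinstance(nested, list):
        out = []
        for item in nested:
            out.extend(flatten(item))
        return out
    return [nested]


def shape_of(nested):
    """Shape of a rectangular nested list, e.g. (2, 2); () for a bare number."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)


def reshape(flat, shape):
    """Rebuild a nested list of the given shape from a flat list (row-major)."""
    if not shape:
        return flat[0]
    step = prod(shape[1:])  # number of elements per slice along the first axis
    return [reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def pack(py_dict, keys):
    """Concatenate all values into one flat buffer, recording shapes and sizes."""
    shapes = [shape_of(py_dict[k]) for k in keys]
    numels = [prod(s) for s in shapes]
    buffer = [x for k in keys for x in flatten(py_dict[k])]
    return buffer, shapes, numels


def unpack(buffer, keys, shapes, numels):
    """Split the flat buffer by element counts and restore each value's shape."""
    out, offset = OrderedDict(), 0
    for k, shape, n in zip(keys, shapes, numels):
        out[k] = reshape(buffer[offset:offset + n], shape)
        offset += n
    return out


def simulated_all_reduce(rank_dicts, op="sum"):
    """Pack each rank's dict, reduce the buffers element-wise, then unpack."""
    keys = list(rank_dicts[0].keys())  # stands in for broadcasting rank 0's keys
    packed = [pack(d, keys) for d in rank_dicts]
    buffers = [p[0] for p in packed]
    # One element-wise reduction over whole buffers, like a single dist.all_reduce.
    reduced = [sum(vals) for vals in zip(*buffers)]
    if op == "mean":
        reduced = [v / len(rank_dicts) for v in reduced]
    _, shapes, numels = packed[0]
    return unpack(reduced, keys, shapes, numels)


rank0 = {"weight": [[1.0, 2.0], [3.0, 4.0]], "count": 2.0}
rank1 = {"count": 4.0, "weight": [[5.0, 6.0], [7.0, 8.0]]}  # different insertion order
print(dict(simulated_all_reduce([rank0, rank1], op="mean")))
# {'weight': [[3.0, 4.0], [5.0, 6.0]], 'count': 3.0}
```

Rank 1 deliberately inserts its keys in a different order. Because both ranks pack in rank 0's key order, their buffers line up element for element; packing each rank in its own order would add `count` to part of `weight`. One large collective is also typically cheaper than many small ones, which is the usual motivation for this kind of packing.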

Callers (1): all_reduce_norm

Calls (5): get_world_size, _get_global_gloo_group, pyobj2tensor, tensor2pyobj, _get_reduce_op
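`pyobj2tensor` and `tensor2pyobj` are what let the list of dict keys travel through `dist.broadcast`, which only moves tensors. Their bodies are not shown on this page; the sketch below assumes they pickle the object into a byte tensor and unpickle it on the other side, and mimics that round trip with a list of byte values in place of a tensor. `obj_to_byte_list` and `byte_list_to_obj` are hypothetical names. Since a broadcast writes into a buffer that each receiving rank has already allocated from its own keys, the serialized key lists presumably must have the same byte length on every rank, which is one more reason for the docstring's same-keys requirement.

```python
import pickle


def obj_to_byte_list(obj):
    """Serialize a picklable object into a list of ints in [0, 255] (tensor stand-in)."""
    return list(pickle.dumps(obj))


def byte_list_to_obj(byte_list):
    """Inverse of obj_to_byte_list: rebuild the object from its byte values."""
    return pickle.loads(bytes(byte_list))


keys_rank0 = ["weight", "count"]
payload = obj_to_byte_list(keys_rank0)
# A broadcast from rank 0 would copy `payload` into every rank's buffer here.
print(byte_list_to_obj(payload))  # ['weight', 'count']

# A rank with a slightly different key set serializes to a different length,
# so its buffer would not match rank 0's in a real broadcast.
mismatch = obj_to_byte_list(["weight", "counts"])
print(len(payload), len(mismatch))
```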

Tested by: no test coverage detected.