Repository: github.com/FoundationVision/ByteTrack

Function _pad_to_largest_tensor

yolox/utils/dist.py:126–153

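Pads the local tensor with zeros so that every rank in `group` holds a tensor of the same size. This is needed because torch all_gather does not support gathering tensors of different shapes.

Returns:
list[int]: the number of elements of the tensor on each rank
Tensor: the local tensor, zero-padded to the largest of those sizes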

Signature: _pad_to_largest_tensor(tensor, group)


```python
# Imports needed to run this excerpt on its own.
import torch
import torch.distributed as dist


def _pad_to_largest_tensor(tensor, group):
    """
    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"
    # Number of elements held locally; `tensor` is expected to be a 1-D uint8 buffer.
    local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device)
        for _ in range(world_size)
    ]
    # Exchange sizes so that every rank learns how many elements every other rank holds.
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]

    max_size = max(size_list)

    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    if local_size != max_size:
        # Append zero bytes (uint8) along dim 0 until the tensor has max_size elements.
        padding = torch.zeros(
            (max_size - local_size,), dtype=torch.uint8, device=tensor.device
        )
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor
```
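The padding step can be illustrated without torch.distributed. The sketch below is a single-process simulation and is not the ByteTrack code. Each simulated rank holds a `bytes` payload, which stands in for the 1-D uint8 tensor the function expects. The per-payload lengths play the role of the gathered `size_list`, and the local payload is right-padded with zero bytes up to the maximum size, so that every "rank" would contribute an equally sized buffer to the collective. The helper name `pad_to_largest` is hypothetical.

```python
from typing import List, Tuple


def pad_to_largest(payloads: List[bytes], rank: int) -> Tuple[List[int], bytes]:
    """Hypothetical single-process stand-in for _pad_to_largest_tensor.

    payloads[i] plays the role of the serialized tensor held by rank i;
    the return value is what rank `rank` would get back.
    """
    # Equivalent of dist.all_gather(size_list, local_size): every rank sees every size.
    size_list = [len(p) for p in payloads]
    max_size = max(size_list)

    local = payloads[rank]
    # Right-pad with zero bytes (torch.uint8 zeros in the original) up to max_size.
    if len(local) != max_size:
        local = local + bytes(max_size - len(local))
    return size_list, local


if __name__ == "__main__":
    payloads = [b"abcde", b"xyz", b"0123456"]
    for rank in range(len(payloads)):
        sizes, padded = pad_to_largest(payloads, rank)
        print(rank, sizes, padded)
```

Every rank receives the same size list, `[5, 3, 7]`. Rank 1's payload comes back as `b"xyz"` followed by four zero bytes. As in the original, the rank that already holds the largest buffer gets it back unchanged.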

Callers (2): all_gather, gather
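Both callers, `all_gather` and `gather`, need the returned size list to undo the padding after the collective. The sketch below assumes they follow the common pattern: pickle the object into a byte buffer, pad it, gather it, then slice each received buffer back to its true length before unpickling. It simulates the receiving side in plain Python within a single process, and no torch.distributed call is made. The names `serialize` and `gather_objects` are hypothetical.

```python
import pickle
from typing import Any, List


def serialize(obj: Any) -> bytes:
    # Hypothetical stand-in for turning an arbitrary object into a byte buffer.
    return pickle.dumps(obj)


def gather_objects(objects: List[Any]) -> List[Any]:
    """Simulate gathering picklable objects from len(objects) ranks."""
    buffers = [serialize(o) for o in objects]
    # What the padding helper provides: the true size of every rank's buffer...
    size_list = [len(b) for b in buffers]
    max_size = max(size_list)
    # ...and buffers padded with zero bytes so that all have length max_size.
    padded = [b + bytes(max_size - len(b)) for b in buffers]

    # Receiving side: trim each buffer to its real size so that only the
    # original serialized bytes are decoded.
    results = []
    for size, buf in zip(size_list, padded):
        assert len(buf) == max_size
        results.append(pickle.loads(buf[:size]))
    return results


if __name__ == "__main__":
    data = [{"rank": 0}, [1, 2, 3], "a longer string payload"]
    print(gather_objects(data))
```

The sketch trims with the size list rather than relying on the decoder to ignore trailing zeros. This keeps the round trip correct for any serialization format, not just pickle.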

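Calls: no other functions defined in this repository. It uses only torch.distributed (get_world_size, all_gather) and torch tensor operations.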

Tested by: no test coverage detected