Returns: list[int]: size of the tensor on each rank; Tensor: padded tensor that has the max size
(tensor, group)
| 124 | |
| 125 | |
| 126 | def _pad_to_largest_tensor(tensor, group): |
| 127 | """ |
| 128 | Returns: |
| 129 | list[int]: size of the tensor, on each rank |
| 130 | Tensor: padded tensor that has the max size |
| 131 | """ |
| 132 | world_size = dist.get_world_size(group=group) |
| 133 | assert ( |
| 134 | world_size >= 1 |
| 135 | ), "comm.gather/all_gather must be called from ranks within the given group!" |
| 136 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) |
| 137 | size_list = [ |
| 138 | torch.zeros([1], dtype=torch.int64, device=tensor.device) |
| 139 | for _ in range(world_size) |
| 140 | ] |
| 141 | dist.all_gather(size_list, local_size, group=group) |
| 142 | size_list = [int(size.item()) for size in size_list] |
| 143 | |
| 144 | max_size = max(size_list) |
| 145 | |
| 146 | # we pad the tensor because torch all_gather does not support |
| 147 | # gathering tensors of different shapes |
| 148 | if local_size != max_size: |
| 149 | padding = torch.zeros( |
| 150 | (max_size - local_size,), dtype=torch.uint8, device=tensor.device |
| 151 | ) |
| 152 | tensor = torch.cat((tensor, padding), dim=0) |
| 153 | return size_list, tensor |
| 154 | |
| 155 | |
| 156 | def all_gather(data, group=None): |
no outgoing calls
no test coverage detected