Run gather on arbitrary picklable data (not necessarily tensors). Args: data: any picklable object; dst (int): destination rank; group: a torch process group. By default, will use a group which contains all ranks on gloo backend. Returns: list[data]: on dst, a list of data gathered from each rank. Otherwise, an empty list.
(data, dst=0, group=None)
| 192 | |
| 193 | |
def gather(data, dst=0, group=None):
    """
    Collect arbitrary picklable objects (not necessarily tensors) on one rank.

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list with the data gathered from each rank.
            On any other rank, an empty list. When the world size (global,
            or of ``group``) is 1, ``[data]`` is returned without any
            communication.
    """
    # Single-process run: nothing to exchange.
    if get_world_size() == 1:
        return [data]

    group = _get_global_gloo_group() if group is None else group
    if dist.get_world_size(group=group) == 1:
        return [data]

    # NOTE(review): this is the rank *within* `group`, and it is compared
    # directly against `dst`. Confirm both use the same numbering when a
    # custom sub-group is passed in.
    is_destination = dist.get_rank(group=group) == dst

    payload = _serialize_to_tensor(data, group)
    size_list, payload = _pad_to_largest_tensor(payload, group)

    if not is_destination:
        # Non-destination ranks only send; they receive nothing back.
        dist.gather(payload, [], dst=dst, group=group)
        return []

    # Destination rank: allocate one receive buffer per entry in size_list,
    # each as large as the biggest payload.
    max_size = max(size_list)
    receive_buffers = [
        torch.empty((max_size,), dtype=torch.uint8, device=payload.device)
        for _ in size_list
    ]
    dist.gather(payload, receive_buffers, dst=dst, group=group)

    # Keep only the first `size` bytes of each buffer (the remainder is
    # presumably padding added by `_pad_to_largest_tensor`) and unpickle.
    # pickle.loads runs whatever the peers sent, so ranks must be trusted.
    return [
        pickle.loads(received.cpu().numpy().tobytes()[:size])
        for size, received in zip(size_list, receive_buffers)
    ]
| 236 | |
| 237 | |
| 238 | def shared_random_seed(): |
no test coverage detected