(tensor_list, tensor, **kwargs)
| 231 | raise ValueError('Not support broadcast for torchacc yet!') |
| 232 | |
def torcacc_all_gather(tensor_list, tensor, **kwargs):
    """All-gather ``tensor`` via ``xm.all_gather`` and fill ``tensor_list``.

    The ``(tensor_list, tensor, **kwargs)`` signature mirrors
    ``torch.distributed.all_gather``; this appears intended as its torchacc
    replacement. ``xm.all_gather`` is called along dim 0. Its result is
    split into ``len(tensor_list)`` chunks with ``torch.tensor_split``, and
    each slot of ``tensor_list`` is overwritten with one chunk. The list
    object itself is mutated; nothing is returned.

    Args:
        tensor_list: Mutable list whose length sets the number of chunks.
            Every slot is replaced with a chunk of the gathered result.
        tensor: Local tensor to gather. 0-dim (scalar) tensors are rejected.
        **kwargs: Forwarded verbatim to ``xm.all_gather``.

    Raises:
        ValueError: If ``tensor`` is a 0-dim tensor.
        AssertionError: If any chunk's size differs from ``tensor.size()``
            (e.g. ``len(tensor_list)`` does not match the number of gathered
            pieces). ``tensor_list`` is left unmodified in that case.
    """
    if len(tensor.size()) == 0:
        raise ValueError(
            'Not support ``all_gather`` scaler type for torchacc!')

    # NOTE(review): kwargs go straight to xm.all_gather; torch.distributed
    # style arguments such as ``group``/``async_op`` are presumably not
    # accepted there -- confirm against the callers.
    res = xm.all_gather(value=tensor, dim=0, **kwargs)
    splits = torch.tensor_split(res, len(tensor_list))

    # Validate every chunk BEFORE writing into tensor_list so a mismatch
    # cannot leave the caller's list half-updated. The check is raised
    # explicitly rather than via ``assert`` so it survives ``python -O``;
    # AssertionError and its message are kept for backward compatibility.
    for split in splits:
        if split.size() != tensor.size():
            raise AssertionError('mismatch size: {}, {}'.format(
                split.size(), tensor.size()))

    # NOTE: slots are rebound to the new chunk tensors rather than written
    # into in place, so references the caller holds to the tensors that were
    # originally in tensor_list are not updated.
    for i, split in enumerate(splits):
        tensor_list[i] = split
| 246 | |
| 247 | collector.add_op( |
| 248 | 'TO', OpSpec(module=None, name=None, |
nothing calls this directly
no test coverage detected