(flag: bool, op=all)
| 570 | |
| 571 | |
def reduce_boolean_flags(flag: bool, op=all) -> bool:
    """Combine a per-rank boolean across all ranks with a reducing callable.

    Every rank contributes its local ``flag``. The flags from all ranks are
    gathered (one entry per rank), converted back to ``bool``, and the
    resulting list is passed to ``op``.

    Args:
        flag: This rank's local value; only its truthiness is used.
        op: Callable applied to the gathered ``list[bool]``. Defaults to
            ``all`` (true only if every rank's flag is true); pass ``any``
            for "true on at least one rank".

    Returns:
        ``flag`` unchanged when the distributed backend is not initialized,
        otherwise whatever ``op`` returns for the gathered flags.
    """
    # Nothing to exchange outside a distributed run: the local value is final.
    if not dist.is_initialized():
        return flag

    target_device = get_accelerator().current_device()
    # Encode the bool as an int tensor so it can be sent through a collective.
    local_bit = torch.tensor(int(bool(flag)), dtype=torch.int, device=target_device)

    num_ranks = dist.get_world_size()
    # Receive buffer sized to hold one element per rank.
    gathered_bits = torch.zeros(num_ranks, dtype=torch.int, device=target_device)
    dist.all_gather_into_tensor(gathered_bits, local_bit)

    # Decode back to Python bools on the host before applying the reducer.
    return op(list(map(bool, gathered_bits.tolist())))
| 582 | |
| 583 | |
| 584 | def allclose_on_all_ranks(actual, expected, assert_message=None, **kwargs) -> None: |
searching dependent graphs…