(timeout_second=36000)
| 16 | |
| 17 | |
def initialize_global_process_group(timeout_second=36000):
    """Join the default NCCL process group and select this process's CUDA device.

    The call is idempotent: if the default process group already exists
    (an earlier call, or a framework that initialised it first), it is reused
    instead of being initialised a second time.

    Args:
        timeout_second: Timeout, in seconds, passed to
            ``torch.distributed.init_process_group``. It has no effect when
            the default process group already exists.

    Returns:
        A ``(local_rank, rank, world_size)`` tuple of ints parsed from the
        ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` environment variables.

    Raises:
        KeyError: If one of the three environment variables is not set.
        ValueError: If one of them is not a valid integer.
    """
    # NOTE: relies on `os` being imported at module level, as the original did.
    import torch.distributed
    from datetime import timedelta

    # init_process_group refuses to create the default group twice, so check
    # first. The original checked is_initialized() only *after* initialising,
    # where the test could never be false.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second))

    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # The process group is guaranteed to exist here, so bind the device
    # unconditionally.
    torch.cuda.set_device(local_rank)
    return local_rank, rank, world_size
no outgoing calls