(args)
| 68 | |
| 69 | |
| 70 | def init_distributed_device(args): |
| 71 | # Distributed training = training on more than one GPU. |
| 72 | # Works in both single and multi-node scenarios. |
| 73 | args.distributed = False |
| 74 | args.world_size = 1 |
| 75 | args.rank = 0 # global rank |
| 76 | args.local_rank = 0 |
| 77 | if args.horovod: |
| 78 | assert hvd is not None, "Horovod is not installed" |
| 79 | hvd.init() |
| 80 | world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) |
| 81 | world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) |
| 82 | local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) |
| 83 | args.local_rank = local_rank |
| 84 | args.rank = world_rank |
| 85 | args.world_size = world_size |
| 86 | # args.local_rank = int(hvd.local_rank()) |
| 87 | # args.rank = hvd.rank() |
| 88 | # args.world_size = hvd.size() |
| 89 | args.distributed = True |
| 90 | os.environ["LOCAL_RANK"] = str(args.local_rank) |
| 91 | os.environ["RANK"] = str(args.rank) |
| 92 | os.environ["WORLD_SIZE"] = str(args.world_size) |
| 93 | print( |
| 94 | f"Distributed training: local_rank={args.local_rank}, " |
| 95 | f"rank={args.rank}, world_size={args.world_size}, " |
| 96 | f"hostname={socket.gethostname()}, pid={os.getpid()}" |
| 97 | ) |
| 98 | elif is_using_distributed(): |
| 99 | if "SLURM_PROCID" in os.environ: |
| 100 | # DDP via SLURM |
| 101 | args.local_rank, args.rank, args.world_size = world_info_from_env() |
| 102 | # SLURM var -> torch.distributed vars in case needed |
| 103 | os.environ["LOCAL_RANK"] = str(args.local_rank) |
| 104 | os.environ["RANK"] = str(args.rank) |
| 105 | os.environ["WORLD_SIZE"] = str(args.world_size) |
| 106 | torch.distributed.init_process_group( |
| 107 | backend=args.dist_backend, |
| 108 | init_method=args.dist_url, |
| 109 | world_size=args.world_size, |
| 110 | rank=args.rank, |
| 111 | ) |
| 112 | elif "OMPI_COMM_WORLD_SIZE" in os.environ: # using Summit cluster |
| 113 | world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) |
| 114 | world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) |
| 115 | local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) |
| 116 | args.local_rank = local_rank |
| 117 | args.rank = world_rank |
| 118 | args.world_size = world_size |
| 119 | torch.distributed.init_process_group( |
| 120 | backend=args.dist_backend, |
| 121 | init_method=args.dist_url, |
| 122 | world_size=args.world_size, |
| 123 | rank=args.rank, |
| 124 | ) |
| 125 | else: |
| 126 | # DDP via torchrun, torch.distributed.launch |
| 127 | args.local_rank, _, _ = world_info_from_env() |
no test coverage detected