| 18 | __builtin__.print = print |
| 19 | |
def init_distributed_mode(args):
    """Set up ``torch.distributed`` from the launcher's environment variables.

    Three cases are handled, in this order:

    1. ``RANK`` and ``WORLD_SIZE`` are both set: rank, world size and local
       GPU index are read from ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK``.
    2. ``SLURM_PROCID`` is set: the values are derived from
       ``SLURM_PROCID``, ``SLURM_NTASKS`` and ``SLURM_NODELIST``, and the
       ``MASTER_ADDR``, ``MASTER_PORT``, ``WORLD_SIZE``, ``RANK`` and
       ``LOCAL_RANK`` variables are exported for the ``env://`` init method.
    3. Neither is set: ``args.distributed`` is set to ``False`` and the
       function returns without touching ``torch.distributed``.

    In cases 1 and 2 the function also exports ``LOCAL_SIZE`` (the visible
    CUDA device count), selects the CUDA device, initialises the process
    group with the ``nccl`` backend and waits on a barrier.

    Args:
        args: Mutable namespace-like object. On return it carries
            ``distributed`` and, in distributed mode, also ``rank``,
            ``world_size``, ``gpu``, ``dist_url`` and ``dist_backend``.

    Raises:
        KeyError: If a required companion variable (``LOCAL_RANK``,
            ``SLURM_NTASKS`` or ``SLURM_NODELIST``) is missing.
        RuntimeError: Under SLURM, if no CUDA device is visible or
            ``scontrol`` prints no hostname.
        subprocess.CalledProcessError: If ``scontrol`` exits with a
            non-zero status.
        OSError: If the ``scontrol`` executable cannot be started.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
        args.dist_url = 'env://'
        os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
    elif 'SLURM_PROCID' in os.environ:
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        if num_gpus < 1:
            # Without this guard the modulo below fails with an opaque
            # ZeroDivisionError.
            raise RuntimeError(
                'SLURM job detected but no CUDA devices are visible; '
                'cannot derive a local rank')
        # Argument list instead of a shell pipeline: the node list is never
        # parsed by a shell, a failing scontrol raises instead of leaking its
        # error text into MASTER_ADDR, and stderr is kept out of the result.
        result = subprocess.run(
            ['scontrol', 'show', 'hostname', node_list],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
            check=True)
        hostnames = result.stdout.split()
        if not hostnames:
            raise RuntimeError(
                'scontrol returned no hostname for SLURM_NODELIST={!r}'.format(
                    node_list))
        # The first expanded hostname is used as the rendezvous address
        # (equivalent to the former `| head -n1`).
        addr = hostnames[0]
        local_rank = proc_id % num_gpus
        # Respect a MASTER_PORT chosen by the user; otherwise use 29500.
        os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        os.environ['LOCAL_RANK'] = str(local_rank)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        args.dist_url = 'env://'
        args.world_size = ntasks
        args.rank = proc_id
        args.gpu = local_rank
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    # NOTE(review): setup_for_distributed is defined outside this block;
    # presumably it limits printing to the rank-0 process - confirm.
    setup_for_distributed(args.rank == 0)