(args)
| 421 | |
| 422 | |
def _resolve_slurm_master_addr(node_list):
    """Return the first hostname that ``scontrol`` reports for *node_list*.

    Raises:
        RuntimeError: if ``scontrol`` cannot be run, exits with a non-zero
            status, or prints no hostname.
    """
    try:
        # An argv list (no shell) passes the node list verbatim, so bracketed
        # ranges such as ``node[01-04]`` are never glob-expanded or otherwise
        # interpreted by a shell.
        output = subprocess.check_output(
            ['scontrol', 'show', 'hostname', node_list],
            universal_newlines=True)
    except (OSError, subprocess.CalledProcessError) as err:
        raise RuntimeError(
            'Could not resolve the master address from SLURM_NODELIST={!r}: {}'
            .format(node_list, err)) from err

    hostnames = output.split()
    if not hostnames:
        raise RuntimeError(
            'scontrol returned no hostname for SLURM_NODELIST={!r}'
            .format(node_list))
    # Equivalent of the former ``| head -n1``: the first listed host is used
    # as the rendezvous address.
    return hostnames[0]


def init_distributed_mode(args):
    """Configure ``torch.distributed`` from the process environment.

    Three cases are handled, in this order:

    * ``RANK`` and ``WORLD_SIZE`` are set in the environment: rank, world
      size and local GPU index are read from ``RANK``, ``WORLD_SIZE`` and
      ``LOCAL_RANK``.
    * ``SLURM_PROCID`` is set: rank and world size are derived from
      ``SLURM_PROCID`` / ``SLURM_NTASKS``, the master address is the first
      host of ``SLURM_NODELIST``, and the matching ``MASTER_*``, ``RANK``,
      ``WORLD_SIZE``, ``LOCAL_RANK`` and ``LOCAL_SIZE`` environment variables
      are exported. An existing ``MASTER_PORT`` is kept, otherwise ``29500``
      is used.
    * Otherwise: ``args.distributed`` is set to ``False`` and the function
      returns without initializing anything.

    In the two distributed cases the function sets ``args.rank``,
    ``args.world_size``, ``args.gpu``, ``args.dist_url`` (``'env://'``),
    ``args.dist_backend`` (``'nccl'``) and ``args.distributed = True``,
    selects the CUDA device, initializes the process group, waits on a
    barrier and finally calls ``setup_for_distributed`` with a flag telling
    whether this process is rank 0.

    Args:
        args: mutable namespace-like object that receives the attributes
            listed above.

    Raises:
        RuntimeError: in the SLURM case, if no CUDA device is visible or the
            master address cannot be resolved through ``scontrol``.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ['RANK'])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
        args.dist_url = 'env://'
        os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
    elif 'SLURM_PROCID' in os.environ:
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        if num_gpus < 1:
            # Without this guard the modulo below fails with an opaque
            # ZeroDivisionError.
            raise RuntimeError(
                'SLURM distributed mode requires at least one visible CUDA '
                'device, but torch.cuda.device_count() returned {}'
                .format(num_gpus))
        local_rank = proc_id % num_gpus
        # Resolve the address before exporting anything so that a failure
        # leaves the environment untouched.
        addr = _resolve_slurm_master_addr(node_list)

        os.environ.setdefault('MASTER_PORT', '29500')
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        os.environ['LOCAL_RANK'] = str(local_rank)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        args.dist_url = 'env://'
        args.world_size = ntasks
        args.rank = proc_id
        args.gpu = local_rank
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    # setup_for_distributed is defined elsewhere in this file; it is told
    # whether this process is the rank-0 process.
    setup_for_distributed(args.rank == 0)
| 462 | |
| 463 | |
| 464 | @torch.no_grad() |
nothing calls this directly
no test coverage detected