(args)
| 214 | |
| 215 | |
def init_distributed_mode(args):
    """Configure ``torch.distributed`` from the launcher's environment.

    Three cases are handled, in this order:

    * ``RANK`` and ``WORLD_SIZE`` are set: rank, world size and GPU index
      are read from ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK``.
    * ``SLURM_PROCID`` is set: the rank is the SLURM process id and the GPU
      index is that rank modulo the number of visible CUDA devices. If the
      caller did not already provide ``args.world_size``, it is taken from
      ``SLURM_NTASKS`` when that variable exists.
    * Otherwise: ``args.distributed`` is set to ``False`` and the function
      returns without touching CUDA or ``torch.distributed``.

    Args:
        args: Mutable namespace-like object. ``args.dist_url`` must already
            be set when a distributed launch is detected. The function
            writes ``rank``, ``gpu``, ``distributed``, ``dist_backend`` and
            (where it can be determined) ``world_size`` onto it.

    Raises:
        RuntimeError: in the SLURM case when no CUDA device is visible.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ['RANK'])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        num_gpus = torch.cuda.device_count()
        if num_gpus == 0:
            # Fail with a readable message instead of a ZeroDivisionError
            # from the modulo below.
            raise RuntimeError(
                'SLURM_PROCID is set but no CUDA device is visible')
        args.gpu = args.rank % num_gpus
        # init_process_group() below needs a world size. Keep any value the
        # caller already supplied; only fall back to SLURM's task count when
        # nothing was provided.
        if (getattr(args, 'world_size', None) is None
                and 'SLURM_NTASKS' in os.environ):
            args.world_size = int(os.environ['SLURM_NTASKS'])
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)
    # Wait until every process has joined the group before continuing.
    torch.distributed.barrier()
    # The helper (defined elsewhere in this file) receives True only on rank 0.
    setup_for_distributed(args.rank == 0)
| 239 | |
| 240 | |
| 241 | def update_from_config(args): |
nothing calls this directly
no test coverage detected