(args)
| 247 | |
| 248 | |
def init_distributed_mode(args):
    """Configure ``torch.distributed`` from the launcher's environment.

    The launch mode is detected in this order:

    1. ``args.dist_on_itp`` is truthy: rank, world size and local rank are
       read from the ``OMPI_COMM_WORLD_*`` variables and a ``tcp://`` URL is
       built from ``MASTER_ADDR`` / ``MASTER_PORT``.
    2. ``SLURM_PROCID`` is set: rank, local rank and world size are read from
       ``SLURM_PROCID`` / ``SLURM_LOCALID`` / ``SLURM_NTASKS``.
    3. ``RANK`` and ``WORLD_SIZE`` are set: values are read from ``RANK`` /
       ``WORLD_SIZE`` / ``LOCAL_RANK`` as they are.
    4. Otherwise ``args.distributed`` is set to ``False`` and the function
       returns without touching torch.

    In the distributed cases the function sets ``args.rank``,
    ``args.world_size``, ``args.gpu``, ``args.dist_backend``,
    ``args.dist_url`` and ``args.distributed``, selects the CUDA device,
    initialises the process group, waits on a barrier and finally calls
    ``setup_for_distributed(args.rank == 0)``.

    Args:
        args: Mutable namespace. Must provide ``dist_on_itp``; ``dist_url``
            is optional and defaults to ``'env://'`` when absent.

    Raises:
        KeyError: If an environment variable required by the detected launch
            mode is missing (for example ``LOCAL_RANK`` in case 3).
    """
    if args.dist_on_itp:
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        # Mirror the values under the generic names as well.
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = int(os.environ['SLURM_LOCALID'])
        args.world_size = int(os.environ['SLURM_NTASKS'])
        os.environ['RANK'] = str(args.rank)
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['WORLD_SIZE'] = str(args.world_size)

        # Only resolve the first host of the allocation when the caller has
        # not already chosen a master address; this avoids spawning scontrol
        # (and requiring SLURM_NODELIST) for a result that would be discarded.
        if 'MASTER_ADDR' not in os.environ:
            node_list = os.environ['SLURM_NODELIST']
            os.environ['MASTER_ADDR'] = subprocess.getoutput(
                f'scontrol show hostname {node_list} | head -n1')
        # env:// rendezvous also needs MASTER_PORT; fall back to 29500 but
        # never override a value that is already exported.
        os.environ.setdefault('MASTER_PORT', '29500')
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    # Only the dist_on_itp branch assigns dist_url; in the other modes fall
    # back to environment-variable rendezvous if the caller did not set one.
    args.dist_url = getattr(args, 'dist_url', 'env://')

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    # Wait until every rank has joined before continuing.
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
| 291 | |
| 292 | |
| 293 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): |
nothing calls this directly
no test coverage detected