(args)
| 269 | |
| 270 | |
def init_distributed_mode(args):
    """Detect the launch environment and initialise ``torch.distributed``.

    Launch styles are checked in this order:

    1. ``args.dist_on_itp`` is truthy (a missing attribute counts as False).
       Rank, world size and GPU index come from the Open MPI variables
       ``OMPI_COMM_WORLD_RANK``, ``OMPI_COMM_WORLD_SIZE`` and
       ``OMPI_COMM_WORLD_LOCAL_RANK``. ``args.dist_url`` is built as
       ``tcp://$MASTER_ADDR:$MASTER_PORT``.
    2. Both ``RANK`` and ``WORLD_SIZE`` are in the environment (e.g. exported
       by ``torchrun``). They are used directly. The GPU index is
       ``LOCAL_RANK`` when that is set, otherwise
       ``rank % torch.cuda.device_count()``.
    3. ``SLURM_PROCID`` is in the environment. It is used as the rank, and the
       GPU index is ``rank % torch.cuda.device_count()``. The world size is
       taken from ``args.world_size`` as given by the caller.

    In every distributed case the resolved values end up in the ``RANK``,
    ``LOCAL_RANK`` and ``WORLD_SIZE`` environment variables, so later code can
    read them from there.

    If no launch style matches, ``args.distributed`` is set to ``False`` and
    the function returns without touching CUDA or ``torch.distributed``.

    Otherwise the function:

    - sets ``args.distributed = True`` and ``args.dist_backend = 'nccl'``;
    - makes ``args.gpu`` the current CUDA device;
    - calls ``torch.distributed.init_process_group`` with ``args.dist_url``
      as the init method;
    - waits on a barrier;
    - calls ``setup_for_distributed(args.rank == 0)``.

    Args:
        args: Mutable namespace (e.g. an argparse result). Attributes read:

            - ``dist_on_itp``: optional.
            - ``dist_url``: cases 2 and 3.
            - ``world_size``: case 3.

            Attributes written: ``rank``, ``world_size`` (cases 1 and 2),
            ``gpu``, ``distributed``, ``dist_backend``, and ``dist_url``
            (case 1 only).

    Raises:
        KeyError: A variable required by the detected launch style is missing
            (e.g. ``MASTER_ADDR`` in case 1).
        ValueError: One of the rank/size variables is not an integer.
    """
    if getattr(args, 'dist_on_itp', False):
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
        # Re-export in the generic names so downstream code can rely on
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"].
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ['RANK'])
        args.world_size = int(os.environ['WORLD_SIZE'])
        local_rank = os.environ.get('LOCAL_RANK')
        if local_rank is not None:
            args.gpu = int(local_rank)
        else:
            # Not every launcher exports LOCAL_RANK. Use the same
            # one-process-per-GPU heuristic as the SLURM branch below.
            # NOTE(review): assumes ranks are packed contiguously per node.
            args.gpu = args.rank % torch.cuda.device_count()
            os.environ['LOCAL_RANK'] = str(args.gpu)
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
        # world_size is not derived from SLURM here; it must already be set
        # correctly on ``args`` by the caller.
        os.environ['RANK'] = str(args.rank)
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['WORLD_SIZE'] = str(args.world_size)
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print(f'| distributed init (rank {args.rank}): {args.dist_url}, gpu {args.gpu}', flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    # Let only rank 0 behave as the primary process from here on.
    setup_for_distributed(args.rank == 0)
| 308 | |
| 309 | |
| 310 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): |
nothing calls this directly
no test coverage detected