Initialize torch.distributed.
(args)
| 313 | |
| 314 | |
def initialize_distributed(args):
    """Bring up torch.distributed for this process.

    Steps, in order:
      1. Bind the process to a CUDA device: ``args.local_rank`` when it is
         not None, otherwise ``args.rank`` modulo the CUDA device count.
      2. Build a ``tcp://<addr>:<port>`` rendezvous URL from the
         ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables
         (falling back to ``localhost`` and ``6000``).
      3. Call ``torch.distributed.init_process_group`` with the backend,
         world size and rank taken from ``args``.
      4. Hand ``args.model_parallel_size`` to
         ``mpu.initialize_model_parallel``.

    Args:
        args: object exposing ``rank``, ``local_rank``, ``world_size``,
            ``distributed_backend`` and ``model_parallel_size``.
    """
    # Pick the device id. The rank-derived fallback is always computed,
    # and an explicit local rank overrides it.
    rank_based_device = args.rank % torch.cuda.device_count()
    if args.local_rank is None:
        device_id = rank_based_device
    else:
        device_id = args.local_rank
    torch.cuda.set_device(device_id)

    # Assemble the TCP rendezvous address for the process group.
    host = os.getenv('MASTER_ADDR', 'localhost')
    port = os.getenv('MASTER_PORT', '6000')
    rendezvous_url = 'tcp://{}:{}'.format(host, port)

    torch.distributed.init_process_group(
        backend=args.distributed_backend,
        world_size=args.world_size,
        rank=args.rank,
        init_method=rendezvous_url,
    )

    # Set the model-parallel / data-parallel communicators.
    mpu.initialize_model_parallel(args.model_parallel_size)
| 335 | |
| 336 | |
| 337 | def set_random_seed(seed): |