Initialize torch.distributed.
(args)
| 432 | |
| 433 | |
def initialize_distributed(args):
    """Bring up torch.distributed for this process and set up the mpu groups.

    Selects the CUDA device for the process, joins the default process group
    through a TCP rendezvous, then asks ``mpu`` to build the model-parallel /
    data-parallel communicators.

    Args:
        args: Parsed run configuration. The attributes read here are
            ``rank``, ``local_rank``, ``distributed_backend``, ``world_size``
            and ``model_parallel_size``.

    Environment:
        MASTER_ADDR: Rendezvous host; falls back to ``'localhost'``.
        MASTER_PORT: Rendezvous port; falls back to ``'6000'``.
    """
    # The rank-derived device is always computed; an explicit local_rank
    # (anything other than None) takes precedence over it.
    rank_based_device = args.rank % torch.cuda.device_count()
    if args.local_rank is None:
        chosen_device = rank_based_device
    else:
        chosen_device = args.local_rank
    torch.cuda.set_device(chosen_device)

    # Assemble the tcp:// rendezvous address from the environment.
    host = os.getenv('MASTER_ADDR', 'localhost')
    port = os.getenv('MASTER_PORT', '6000')
    rendezvous_url = f'tcp://{host}:{port}'

    torch.distributed.init_process_group(
        backend=args.distributed_backend,
        world_size=args.world_size,
        rank=args.rank,
        init_method=rendezvous_url,
    )

    # Set the model-parallel / data-parallel communicators.
    mpu.initialize_model_parallel(args.model_parallel_size)
| 454 | |
| 455 | |
| 456 | def set_random_seed(seed): |