Initialize torch.distributed.
(args)
| 641 | |
| 642 | |
def initialize_distributed(args):
    """Set up torch.distributed and the model-parallel groups for this process.

    Steps, in order:
      1. Select the CUDA device: ``args.local_rank`` when it is not None,
         otherwise ``args.rank`` modulo the number of visible CUDA devices.
      2. Build a ``tcp://<addr>:<port>`` rendezvous URL from the
         ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables
         (defaults: ``localhost`` / ``6000``) and initialize the default
         process group with it.
      3. Call ``mpu.initialize_model_parallel`` with
         ``args.model_parallel_size``.
      4. If ``args`` has a truthy ``deepspeed`` attribute and
         ``args.deepspeed_activation_checkpointing`` is truthy, call
         ``set_deepspeed_activation_checkpointing(args)``.

    Args:
        args: Parsed run configuration. Reads ``rank``, ``local_rank``,
            ``world_size``, ``distributed_backend``, ``model_parallel_size``
            and, optionally, ``deepspeed`` and
            ``deepspeed_activation_checkpointing``.
    """
    # Pick the device id by hand. The rank-derived id is always computed
    # first; an explicit local rank, when given, takes precedence over it.
    rank_based_device = args.rank % torch.cuda.device_count()
    local_rank = args.local_rank
    torch.cuda.set_device(
        rank_based_device if local_rank is None else local_rank)

    # Rendezvous endpoint for the default process group.
    master_addr = os.environ.get('MASTER_ADDR', 'localhost')
    master_port = os.environ.get('MASTER_PORT', '6000')
    rendezvous_url = 'tcp://' + master_addr + ':' + master_port

    torch.distributed.init_process_group(
        backend=args.distributed_backend,
        world_size=args.world_size,
        rank=args.rank,
        init_method=rendezvous_url,
    )

    # Model-parallel / data-parallel communicators.
    mpu.initialize_model_parallel(args.model_parallel_size)

    # Optional DeepSpeed activation-checkpointing features. The second
    # attribute is only read when DeepSpeed itself is enabled.
    deepspeed_enabled = getattr(args, "deepspeed", False)
    if deepspeed_enabled and args.deepspeed_activation_checkpointing:
        set_deepspeed_activation_checkpointing(args)
| 668 | |
| 669 | |
| 670 | def set_random_seed(seed): |
no test coverage detected