Initialize torch.distributed.
(args)
| 513 | |
| 514 | |
def initialize_distributed(args):
    """Initialize torch.distributed and the model-parallel communicators.

    Binds this process to a CUDA device and joins the default process group
    using TCP rendezvous. It then sets up the model-parallel / data-parallel
    groups through ``mpu``.

    Args:
        args: Namespace-like object. The attributes read are ``rank``,
            ``local_rank`` (may be ``None``), ``world_size``,
            ``distributed_backend``, ``model_parallel_size`` and
            ``partition_activations``.

    Environment:
        MASTER_ADDR: rendezvous host, default ``'localhost'``.
        MASTER_PORT: rendezvous port, default ``'6000'``.

    Raises:
        ZeroDivisionError: if ``args.local_rank`` is ``None`` and
            ``torch.cuda.device_count()`` returns 0 (``rank % 0``).
    """

    # Manually set the device ids. Prefer the launcher-provided local rank.
    # Only fall back to ``rank % device_count`` when it is absent, so an
    # explicit local rank never triggers a modulo-by-zero on a host that
    # reports no CUDA devices.
    if args.local_rank is not None:
        device = args.local_rank
    else:
        device = args.rank % torch.cuda.device_count()
    torch.cuda.set_device(device)

    # Call the init process: TCP rendezvous at MASTER_ADDR:MASTER_PORT.
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method = 'tcp://{}:{}'.format(master_ip, master_port)
    torch.distributed.init_process_group(
        backend=args.distributed_backend,
        world_size=args.world_size, rank=args.rank,
        init_method=init_method)

    # Set the model-parallel / data-parallel communicators.
    mpu.initialize_model_parallel(args.model_parallel_size)

    # Checkpoints are partitioned across the model-parallel processes rather
    # than replicated on each one (as in the original Megatron).
    # ``args.partition_activations`` is passed through as the setting.
    mpu.partition_activations_in_checkpoint(args.partition_activations)
| 540 | |
| 541 |
No test coverage detected.