()
| 20 | } |
| 21 | |
def _env_int(name: str) -> int:
    """Return environment variable *name* parsed as an int, or raise ValueError."""
    raw = os.environ.get(name)
    if raw is None:
        raise ValueError(
            f"environment variable {name} is not set; launch this script with a "
            f"distributed launcher (e.g. torchrun) or export {name} explicitly"
        )
    try:
        return int(raw)
    except ValueError:
        raise ValueError(
            f"environment variable {name}={raw!r} is not an integer"
        ) from None


def setup_model_parallel(*, seed: int = 1) -> Tuple[int, int]:
    """Initialise distributed state and model parallelism for this process.

    The function does the following, in order:

    * Reads ``LOCAL_RANK`` and ``WORLD_SIZE`` from the environment. These are
      expected to be exported by the launcher, e.g. ``torchrun``.
    * Creates the default NCCL process group if none exists yet.
    * Calls ``initialize_model_parallel(world_size)``.
    * Selects CUDA device ``LOCAL_RANK`` for this process.
    * Seeds the torch RNG.

    Args:
        seed: Value passed to ``torch.manual_seed``. Defaults to 1, which was
            the previously hard-coded value.

    Returns:
        ``(local_rank, world_size)`` as read from the environment.

    Raises:
        ValueError: If ``LOCAL_RANK`` or ``WORLD_SIZE`` is unset or not an
            integer, if ``WORLD_SIZE < 1``, or if ``LOCAL_RANK < 0``.
    """
    # Validate before touching torch.distributed. The old ``-1`` defaults were
    # passed straight to initialize_model_parallel (nonsensical group size) and
    # to torch.cuda.set_device (a silent no-op for negative indices), hiding a
    # misconfigured launch.
    local_rank = _env_int("LOCAL_RANK")
    world_size = _env_int("WORLD_SIZE")
    if world_size < 1:
        raise ValueError(f"WORLD_SIZE must be >= 1, got {world_size}")
    if local_rank < 0:
        raise ValueError(f"LOCAL_RANK must be >= 0, got {local_rank}")

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group("nccl")
    # The whole world size is used as the model-parallel size here.
    # NOTE(review): presumably fairscale's initialize_model_parallel; confirm.
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)
    torch.manual_seed(seed)
    return local_rank, world_size
| 31 | |
| 32 | |
| 33 | def load( |