(kvargs)
| 78 | |
| 79 | |
def init_distributed_env(kvargs):
    """Record this process's rank layout and initialize the NCCL process group.

    The function publishes the rank / world-size bookkeeping through the
    module's ``set_*`` helpers, binds the process to a CUDA device, creates
    the ``torch.distributed`` process group over TCP, and finally issues one
    ``all_reduce`` to warm up the communicator.

    Args:
        kvargs: Mapping read with the following keys:
            ``"world_size"`` - total number of ranks;
            ``"rank_id"`` - global rank of this process;
            ``"args"`` - object exposing an ``nnodes`` attribute;
            ``"nccl_host"`` / ``"nccl_port"`` - endpoint used to build the
            ``tcp://host:port`` init method;
            ``"dp_size"`` - optional, defaults to ``1`` when absent.

    Raises:
        AssertionError: If ``world_size`` is not a multiple of ``nnodes``.
        KeyError: If a required key is missing from ``kvargs``.
    """
    total_ranks = kvargs["world_size"]
    node_count = kvargs["args"].nnodes
    assert total_ranks % node_count == 0, "world_size should be divided by nnodes"
    ranks_per_node = total_ranks // node_count

    rank = kvargs["rank_id"]

    # Publish the global layout first; the derived values below are computed
    # from the getters, so the setter order must be kept as-is.
    set_global_rank(rank)
    set_global_world_size(total_ranks)
    set_dp_size(kvargs.get("dp_size", 1))

    ranks_per_dp_group = get_global_world_size() // get_dp_size()
    set_dp_world_size(ranks_per_dp_group)

    dp_group_index = get_global_rank() // get_dp_world_size()
    set_global_dp_rank(dp_group_index)

    rank_inside_dp_group = get_global_rank() % get_dp_world_size()
    set_current_rank_in_dp(rank_inside_dp_group)

    rank_inside_node = get_global_rank() % ranks_per_node
    set_current_rank_in_node(rank_inside_node)
    set_node_world_size(ranks_per_node)

    # Number of dp groups hosted per node, clamped to at least one so the
    # modulo below is always well defined.
    dp_groups_per_node = max(1, get_dp_size() // node_count)
    set_dp_rank_in_node(get_global_dp_rank() % dp_groups_per_node)

    # NOTE(review): presumably configures NCCL-related settings before the
    # process group is created - confirm against _init_nccl_env.
    _init_nccl_env()

    # The device index is this rank's position within its node.
    local_device = rank % get_node_world_size()
    set_current_device_id(local_device)
    torch.cuda.set_device(local_device)

    rendezvous_url = f'tcp://{kvargs["nccl_host"]}:{kvargs["nccl_port"]}'
    dist.init_process_group(
        "nccl",
        init_method=rendezvous_url,
        rank=rank,
        world_size=total_ranks,
    )

    # Warm up the NCCL communicator with a single one-element all_reduce.
    warmup_tensor = torch.zeros([1]).to(f"cuda:{local_device}")
    dist.all_reduce(warmup_tensor)
    del warmup_tensor
| 111 | |
| 112 | |
| 113 | def set_global_rank(global_rank: int): |
no test coverage detected