Initialize the role of the current process. Each process is associated with a role so that we can determine which functions can be invoked in that process. For example, we do not allow some functions in sampler processes. The initialization includes registering the role of the current process and getting the roles of all client processes.
(role)
| 124 | |
| 125 | |
| 126 | def init_role(role): |
| 127 | """Initialize the role of the current process. |
| 128 | |
| 129 | Each process is associated with a role so that we can determine what |
| 130 | function can be invoked in a process. For example, we do not allow some |
| 131 | functions in sampler processes. |
| 132 | |
| 133 | The initialization includes registeration the role of the current process and |
| 134 | get the roles of all client processes. It also computes the rank of all client |
| 135 | processes in a deterministic way so that all clients will have the same rank for |
| 136 | the same client process. |
| 137 | """ |
| 138 | global CUR_ROLE |
| 139 | CUR_ROLE = role |
| 140 | |
| 141 | global PER_ROLE_RANK |
| 142 | global GLOBAL_RANK |
| 143 | global IS_STANDALONE |
| 144 | |
| 145 | if os.environ.get("DGL_DIST_MODE", "standalone") == "standalone": |
| 146 | if role == "default": |
| 147 | GLOBAL_RANK[0] = 0 |
| 148 | PER_ROLE_RANK["default"] = {0: 0} |
| 149 | IS_STANDALONE = True |
| 150 | return |
| 151 | |
| 152 | PER_ROLE_RANK = {} |
| 153 | GLOBAL_RANK = {} |
| 154 | |
| 155 | # Register the current role. This blocks until all clients register themselves. |
| 156 | client_id = rpc.get_rank() |
| 157 | machine_id = rpc.get_machine_id() |
| 158 | request = RegisterRoleRequest(client_id, machine_id, role) |
| 159 | rpc.send_request(0, request) |
| 160 | response = rpc.recv_response() |
| 161 | assert response.msg == REG_ROLE_MSG |
| 162 | |
| 163 | # Get all clients on all machines. |
| 164 | request = GetRoleRequest() |
| 165 | rpc.send_request(0, request) |
| 166 | response = rpc.recv_response() |
| 167 | assert response.msg == GET_ROLE_MSG |
| 168 | |
| 169 | # Here we want to compute a new rank for each client. |
| 170 | # We compute the per-role rank as well as global rank. |
| 171 | # For per-role rank, we ensure that all ranks within a machine is contiguous. |
| 172 | # For global rank, we also ensure that all ranks within a machine are contiguous, |
| 173 | # and all ranks within a role are contiguous. |
| 174 | global_rank = 0 |
| 175 | |
| 176 | # We want to ensure that the global rank of the trainer process starts from 0. |
| 177 | role_names = ["default"] |
| 178 | for role_name in response.role: |
| 179 | if role_name not in role_names: |
| 180 | role_names.append(role_name) |
| 181 | |
| 182 | for role_name in role_names: |
| 183 | # Let's collect the ranks of this role in all machines. |
no test coverage detected