Helper function that constructs the torch distributed launcher command string. Args: num_trainers, num_nodes, node_rank, master_addr, master_port. Returns: cmd_str, the formatted launcher command.
(
num_trainers: int,
num_nodes: int,
node_rank: int,
master_addr: str,
master_port: int,
)
| 190 | |
| 191 | |
def construct_torch_dist_launcher_cmd(
    num_trainers: int,
    num_nodes: int,
    node_rank: int,
    master_addr: str,
    master_port: int,
) -> str:
    """Constructs the torch distributed launcher command.

    Helper function. The returned string is only the module invocation and
    its options: it starts with ``-m torch.distributed.launch`` and carries
    no interpreter/executable prefix.

    Args:
        num_trainers: Value substituted for ``--nproc_per_node``.
        num_nodes: Value substituted for ``--nnodes``.
        node_rank: Value substituted for ``--node_rank``.
        master_addr: Value substituted for ``--master_addr``.
        master_port: Value substituted for ``--master_port``.

    Returns:
        cmd_str: The launcher options joined by single spaces, with no
        trailing whitespace. Values are inserted via ``str.format`` without
        any quoting or escaping.
    """
    # Adjacent literals concatenate into one template; every fragment except
    # the last ends with the space that separates it from the next option.
    torch_cmd_template = (
        "-m torch.distributed.launch "
        "--nproc_per_node={nproc_per_node} "
        "--nnodes={nnodes} "
        "--node_rank={node_rank} "
        "--master_addr={master_addr} "
        "--master_port={master_port}"
    )
    return torch_cmd_template.format(
        nproc_per_node=num_trainers,
        nnodes=num_nodes,
        node_rank=node_rank,
        master_addr=master_addr,
        master_port=master_port,
    )
| 227 | |
| 228 | |
| 229 | def wrap_udf_in_torch_dist_launcher( |
no test coverage detected