| 32 | |
| 33 | |
def get_cluster_from_args(selected_gpus, num_nodes=1):
    """Build a localhost-only cluster description for ``selected_gpus``.

    One free port is requested per selected GPU. Every node IP (here only the
    hard-coded ``127.0.0.1``) receives one ``"ip:port"`` trainer endpoint per
    port, and the result is handed to ``get_cluster``.

    Args:
        selected_gpus: Sized collection of selected GPUs. Only ``len()`` is used
            here; the object is forwarded unchanged to ``get_cluster``.
        num_nodes: Currently unused by this function; kept for interface
            compatibility.

    Returns:
        Whatever ``get_cluster(node_ips, node_ip, trainer_endpoints,
        selected_gpus)`` returns.

    Raises:
        RuntimeError: If ``find_free_ports`` returns ``None``.
    """
    # Cluster topology is hard-coded to a single local node.
    cluster_node_ips = "127.0.0.1"
    node_ip = "127.0.0.1"

    node_ips = [x.strip() for x in cluster_node_ips.split(",")]

    # Sanity check: the current node must be one of the cluster nodes;
    # list.index raises ValueError otherwise.
    node_ips.index(node_ip)

    free_ports = find_free_ports(len(selected_gpus))
    if free_ports is None:
        # Fail fast with a descriptive error instead of a TypeError when the
        # endpoint comprehension below tries to iterate over None.
        raise RuntimeError(
            "could not find %d free ports for selected GPUs %r"
            % (len(selected_gpus), selected_gpus))
    free_ports = list(free_ports)

    # One list of "ip:port" endpoints per node, one endpoint per free port.
    trainer_endpoints = [
        ["%s:%d" % (ip, port) for port in free_ports] for ip in node_ips
    ]

    return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus)
| 54 | |
| 55 | def get_gpus(selected_gpus): |