Submit distributed jobs (server and client processes) via ssh
(args, udf_command, dry_run=False)
| 502 | |
| 503 | |
| 504 | def submit_jobs(args, udf_command, dry_run=False): |
| 505 | """Submit distributed jobs (server and client processes) via ssh""" |
| 506 | if dry_run: |
| 507 | print( |
| 508 | "Currently it's in dry run mode which means no jobs will be launched." |
| 509 | ) |
| 510 | servers_cmd = [] |
| 511 | clients_cmd = [] |
| 512 | hosts = [] |
| 513 | thread_list = [] |
| 514 | server_count_per_machine = 0 |
| 515 | |
| 516 | # Get the IP addresses of the cluster. |
| 517 | ip_config = os.path.join(args.workspace, args.ip_config) |
| 518 | with open(ip_config) as f: |
| 519 | for line in f: |
| 520 | result = line.strip().split() |
| 521 | if len(result) == 2: |
| 522 | ip = result[0] |
| 523 | port = int(result[1]) |
| 524 | hosts.append((ip, port)) |
| 525 | elif len(result) == 1: |
| 526 | ip = result[0] |
| 527 | port = get_available_port(ip) |
| 528 | hosts.append((ip, port)) |
| 529 | else: |
| 530 | raise RuntimeError("Format error of ip_config.") |
| 531 | server_count_per_machine = args.num_servers |
| 532 | # Get partition info of the graph data |
| 533 | part_config = os.path.join(args.workspace, args.part_config) |
| 534 | with open(part_config) as conf_f: |
| 535 | part_metadata = json.load(conf_f) |
| 536 | assert "num_parts" in part_metadata, "num_parts does not exist." |
| 537 | # The number of partitions must match the number of machines in the cluster. |
| 538 | assert part_metadata["num_parts"] == len( |
| 539 | hosts |
| 540 | ), "The number of graph partitions has to match the number of machines in the cluster." |
| 541 | |
| 542 | state_q = queue.Queue() |
| 543 | tot_num_clients = args.num_trainers * (1 + args.num_samplers) * len(hosts) |
| 544 | # launch server tasks |
| 545 | server_env_vars = construct_dgl_server_env_vars( |
| 546 | num_samplers=args.num_samplers, |
| 547 | num_server_threads=args.num_server_threads, |
| 548 | tot_num_clients=tot_num_clients, |
| 549 | part_config=args.part_config, |
| 550 | ip_config=args.ip_config, |
| 551 | num_servers=args.num_servers, |
| 552 | graph_format=args.graph_format, |
| 553 | pythonpath=os.environ.get("PYTHONPATH", ""), |
| 554 | ) |
| 555 | for i in range(len(hosts) * server_count_per_machine): |
| 556 | ip, _ = hosts[int(i / server_count_per_machine)] |
| 557 | server_env_vars_cur = f"{server_env_vars} DGL_SERVER_ID={i}" |
| 558 | cmd = wrap_cmd_with_local_envvars(udf_command, server_env_vars_cur) |
| 559 | cmd = ( |
| 560 | wrap_cmd_with_extra_envvars(cmd, args.extra_envs) |
| 561 | if len(args.extra_envs) > 0 |