Main function.
(args)
| 335 | |
| 336 | |
| 337 | def main(args): |
| 338 | """ |
| 339 | Main function. |
| 340 | """ |
| 341 | host_name = socket.gethostname() |
| 342 | print(f"{host_name}: Initializing DistDGL.") |
| 343 | dgl.distributed.initialize(args.ip_config, use_graphbolt=args.use_graphbolt) |
| 344 | print(f"{host_name}: Initializing PyTorch process group.") |
| 345 | th.distributed.init_process_group(backend=args.backend) |
| 346 | print(f"{host_name}: Initializing DistGraph.") |
| 347 | g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config) |
| 348 | print(f"Rank of {host_name}: {g.rank()}") |
| 349 | |
| 350 | # Split train/val/test IDs for each trainer. |
| 351 | pb = g.get_partition_book() |
| 352 | if "trainer_id" in g.ndata: |
| 353 | train_nid = dgl.distributed.node_split( |
| 354 | g.ndata["train_mask"], |
| 355 | pb, |
| 356 | force_even=True, |
| 357 | node_trainer_ids=g.ndata["trainer_id"], |
| 358 | ) |
| 359 | val_nid = dgl.distributed.node_split( |
| 360 | g.ndata["val_mask"], |
| 361 | pb, |
| 362 | force_even=True, |
| 363 | node_trainer_ids=g.ndata["trainer_id"], |
| 364 | ) |
| 365 | test_nid = dgl.distributed.node_split( |
| 366 | g.ndata["test_mask"], |
| 367 | pb, |
| 368 | force_even=True, |
| 369 | node_trainer_ids=g.ndata["trainer_id"], |
| 370 | ) |
| 371 | else: |
| 372 | train_nid = dgl.distributed.node_split( |
| 373 | g.ndata["train_mask"], pb, force_even=True |
| 374 | ) |
| 375 | val_nid = dgl.distributed.node_split( |
| 376 | g.ndata["val_mask"], pb, force_even=True |
| 377 | ) |
| 378 | test_nid = dgl.distributed.node_split( |
| 379 | g.ndata["test_mask"], pb, force_even=True |
| 380 | ) |
| 381 | local_nid = pb.partid2nids(pb.partid).detach().numpy() |
| 382 | num_train_local = len(np.intersect1d(train_nid.numpy(), local_nid)) |
| 383 | num_val_local = len(np.intersect1d(val_nid.numpy(), local_nid)) |
| 384 | num_test_local = len(np.intersect1d(test_nid.numpy(), local_nid)) |
| 385 | print( |
| 386 | f"part {g.rank()}, train: {len(train_nid)} (local: {num_train_local}), " |
| 387 | f"val: {len(val_nid)} (local: {num_val_local}), " |
| 388 | f"test: {len(test_nid)} (local: {num_test_local})" |
| 389 | ) |
| 390 | del local_nid |
| 391 | if args.num_gpus == 0: |
| 392 | device = th.device("cpu") |
| 393 | else: |
| 394 | dev_id = g.rank() % args.num_gpus |
no test coverage detected