(args)
| 370 | |
| 371 | |
def main(args):
    """Set up the distributed graph, collect node splits, and launch training.

    Entry point for distributed unsupervised GraphSAGE node classification.
    It initializes DGL's distributed runtime and loads the partitioned
    ``DistGraph``. It then computes this trainer's share of all edges and
    nodes and collects the global train/validation/test node IDs and labels.
    Finally it picks a device and hands everything to
    ``run(args, device, data)``.

    Args:
        args: Parsed command-line namespace. This function reads
            ``ip_config``, ``standalone``, ``graph_name``, ``part_config``
            and ``num_gpus``; the whole namespace is also forwarded to
            ``run``. A non-positive ``num_gpus`` (e.g. ``-1``) selects the CPU.
    """
    print("--- Distributed node classification with GraphSAGE unsupervised ---")
    dgl.distributed.initialize(args.ip_config)
    if not args.standalone:
        # The torch process group is skipped in standalone mode.
        th.distributed.init_process_group(backend="gloo")
    g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
    print("rank:", g.rank())
    print("number of edges", g.num_edges())

    # Every edge / node is eligible; the split helpers partition them
    # across trainers using the graph's partition book.
    train_eids = dgl.distributed.edge_split(
        th.ones((g.num_edges(),), dtype=th.bool),
        g.get_partition_book(),
        force_even=True,
    )
    train_nids = dgl.distributed.node_split(
        th.ones((g.num_nodes(),), dtype=th.bool), g.get_partition_book()
    )

    # Build the full node-ID range once and reuse it for every lookup below.
    all_nids = np.arange(g.num_nodes())

    def _mask_to_nids(mask_name):
        """Return a 1-D LongTensor of node IDs where ndata[mask_name] is set."""
        # reshape(-1) instead of squeeze(): whatever 2-D layout the nonzero
        # result has, this always yields a 1-D vector, even when exactly one
        # node is selected (squeeze() would give a 0-d tensor there and break
        # ``.shape[0]`` below).
        return th.LongTensor(np.nonzero(g.ndata[mask_name][all_nids])).reshape(-1)

    global_train_nid = _mask_to_nids("train_mask")
    global_valid_nid = _mask_to_nids("val_mask")
    global_test_nid = _mask_to_nids("test_mask")
    labels = g.ndata["labels"][all_nids]

    if args.num_gpus <= 0:
        # Non-positive num_gpus means CPU; this also avoids the modulo by
        # zero below when num_gpus == 0.
        device = th.device("cpu")
    else:
        # Trainers are assigned to GPUs round-robin by rank.
        dev_id = g.rank() % args.num_gpus
        device = th.device("cuda:" + str(dev_id))

    # Pack data
    in_feats = g.ndata["features"].shape[1]
    print("number of train {}".format(global_train_nid.shape[0]))
    print("number of valid {}".format(global_valid_nid.shape[0]))
    print("number of test {}".format(global_test_nid.shape[0]))
    data = (
        train_eids,
        train_nids,
        in_feats,
        g,
        global_train_nid,
        global_valid_nid,
        global_test_nid,
        labels,
    )
    run(args, device, data)
    print("parent ends")
| 425 | |
| 426 | |
| 427 | if __name__ == "__main__": |
no test coverage detected