MCPcopy
hub / github.com/dmlc/dgl / main

Function main

examples/distributed/graphsage/node_classification_unsupervised.py:372–424  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

370
371
def _mask_to_nids(mask):
    """Return the positions of the nonzero entries of ``mask`` as a 1-D int64 tensor.

    The mask is flattened first, and only the single index array from
    ``np.nonzero`` is kept. The result is therefore always 1-D: empty when no
    entry is set, and of shape ``(1,)`` when exactly one is set.

    This differs from ``th.LongTensor(np.nonzero(mask)).squeeze()``, whose
    ``squeeze()`` collapses a single hit into a 0-d tensor. A 0-d tensor has
    no ``shape[0]``.
    """
    flat = np.asarray(mask).reshape(-1)
    return th.as_tensor(np.nonzero(flat)[0], dtype=th.int64)


def main(args):
    """Set up the distributed graph and data, then hand off to ``run``.

    Steps:
      * initialize DGL's distributed runtime (and a gloo torch process group
        unless ``args.standalone`` is set);
      * open the partitioned ``DistGraph``;
      * split all edges and all nodes across trainers;
      * collect the global train/valid/test node IDs and the labels of every
        node;
      * pick a device;
      * call ``run(args, device, data)``.

    Args:
        args: Parsed command-line namespace. This function reads
            ``ip_config``, ``standalone``, ``graph_name``, ``part_config`` and
            ``num_gpus``; ``run`` may read additional fields.
    """
    print("--- Distributed node classification with GraphSAGE unsuperised ---")
    dgl.distributed.initialize(args.ip_config)
    if not args.standalone:
        th.distributed.init_process_group(backend="gloo")
    g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
    print("rank:", g.rank())
    print("number of edges", g.num_edges())

    partition_book = g.get_partition_book()
    num_nodes = g.num_nodes()

    # Every edge and every node is selected (all-True masks). edge_split is
    # asked for an even split across trainers via force_even=True.
    train_eids = dgl.distributed.edge_split(
        th.ones((g.num_edges(),), dtype=th.bool),
        partition_book,
        force_even=True,
    )
    train_nids = dgl.distributed.node_split(
        th.ones((num_nodes,), dtype=th.bool), partition_book
    )

    # Each trainer pulls the masks and labels for ALL nodes, not only its own
    # partition. The index array is built once and reused.
    all_nids = np.arange(num_nodes)
    global_train_nid = _mask_to_nids(g.ndata["train_mask"][all_nids])
    global_valid_nid = _mask_to_nids(g.ndata["val_mask"][all_nids])
    global_test_nid = _mask_to_nids(g.ndata["test_mask"][all_nids])
    labels = g.ndata["labels"][all_nids]

    if args.num_gpus == -1:
        device = th.device("cpu")
    else:
        # Trainers are mapped onto the available GPUs round-robin by rank.
        dev_id = g.rank() % args.num_gpus
        device = th.device("cuda:" + str(dev_id))

    # Pack data
    in_feats = g.ndata["features"].shape[1]
    print("number of train {}".format(global_train_nid.shape[0]))
    print("number of valid {}".format(global_valid_nid.shape[0]))
    print("number of test {}".format(global_test_nid.shape[0]))
    data = (
        train_eids,
        train_nids,
        in_feats,
        g,
        global_train_nid,
        global_valid_nid,
        global_test_nid,
        labels,
    )
    run(args, device, data)
    print("parent ends")
425
426
427if __name__ == "__main__":

Calls 8

rank — Method · 0.95
num_edges — Method · 0.95
get_partition_book — Method · 0.95
num_nodes — Method · 0.95
nonzero — Method · 0.80
format — Method · 0.80
run — Function · 0.70
device — Method · 0.45

Tested by

no test coverage detected