MCPcopy
hub / github.com/dmlc/dgl / main

Function main

examples/distributed/graphsage/node_classification.py:337–413  ·  view source on GitHub ↗

Main function.

(args)

Source from the content-addressed store, hash-verified

335
336
337def main(args):
338 """
339 Main function.
340 """
341 host_name = socket.gethostname()
342 print(f"{host_name}: Initializing DistDGL.")
343 dgl.distributed.initialize(args.ip_config, use_graphbolt=args.use_graphbolt)
344 print(f"{host_name}: Initializing PyTorch process group.")
345 th.distributed.init_process_group(backend=args.backend)
346 print(f"{host_name}: Initializing DistGraph.")
347 g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
348 print(f"Rank of {host_name}: {g.rank()}")
349
350 # Split train/val/test IDs for each trainer.
351 pb = g.get_partition_book()
352 if "trainer_id" in g.ndata:
353 train_nid = dgl.distributed.node_split(
354 g.ndata["train_mask"],
355 pb,
356 force_even=True,
357 node_trainer_ids=g.ndata["trainer_id"],
358 )
359 val_nid = dgl.distributed.node_split(
360 g.ndata["val_mask"],
361 pb,
362 force_even=True,
363 node_trainer_ids=g.ndata["trainer_id"],
364 )
365 test_nid = dgl.distributed.node_split(
366 g.ndata["test_mask"],
367 pb,
368 force_even=True,
369 node_trainer_ids=g.ndata["trainer_id"],
370 )
371 else:
372 train_nid = dgl.distributed.node_split(
373 g.ndata["train_mask"], pb, force_even=True
374 )
375 val_nid = dgl.distributed.node_split(
376 g.ndata["val_mask"], pb, force_even=True
377 )
378 test_nid = dgl.distributed.node_split(
379 g.ndata["test_mask"], pb, force_even=True
380 )
381 local_nid = pb.partid2nids(pb.partid).detach().numpy()
382 num_train_local = len(np.intersect1d(train_nid.numpy(), local_nid))
383 num_val_local = len(np.intersect1d(val_nid.numpy(), local_nid))
384 num_test_local = len(np.intersect1d(test_nid.numpy(), local_nid))
385 print(
386 f"part {g.rank()}, train: {len(train_nid)} (local: {num_train_local}), "
387 f"val: {len(val_nid)} (local: {num_val_local}), "
388 f"test: {len(test_nid)} (local: {num_test_local})"
389 )
390 del local_nid
391 if args.num_gpus == 0:
392 device = th.device("cpu")
393 else:
394 dev_id = g.rank() % args.num_gpus

Callers 1

Calls 6

rank — Method · 0.95
get_partition_book — Method · 0.95
num_nodes — Method · 0.95
run — Function · 0.70
partid2nids — Method · 0.45
device — Method · 0.45

Tested by

no test coverage detected