MCPcopy
hub / github.com/dmlc/dgl / _test_construct_graphs_node_ids

Function _test_construct_graphs_node_ids

tests/python/common/data/test_data.py:430–509  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

428
429
430def _test_construct_graphs_node_ids():
431 from dgl.data.csv_dataset_base import (
432 DGLGraphConstructor,
433 EdgeData,
434 NodeData,
435 )
436
437 num_nodes = 100
438 num_edges = 1000
439
440 # node IDs are required to be unique
441 node_ids = np.random.choice(np.arange(num_nodes / 2), num_nodes)
442 src_ids = np.random.choice(node_ids, size=num_edges)
443 dst_ids = np.random.choice(node_ids, size=num_edges)
444 node_data = NodeData(node_ids, {})
445 edge_data = EdgeData(src_ids, dst_ids, {})
446 expect_except = False
447 try:
448 _, _ = DGLGraphConstructor.construct_graphs(node_data, edge_data)
449 except:
450 expect_except = True
451 assert expect_except
452
453 # node IDs are already labelled from 0~num_nodes-1
454 node_ids = np.arange(num_nodes)
455 np.random.shuffle(node_ids)
456 _, idx = np.unique(node_ids, return_index=True)
457 src_ids = np.random.choice(node_ids, size=num_edges)
458 dst_ids = np.random.choice(node_ids, size=num_edges)
459 node_feat = np.random.rand(num_nodes, 3)
460 node_data = NodeData(node_ids, {"feat": node_feat})
461 edge_data = EdgeData(src_ids, dst_ids, {})
462 graphs, data_dict = DGLGraphConstructor.construct_graphs(
463 node_data, edge_data
464 )
465 assert len(graphs) == 1
466 assert len(data_dict) == 0
467 g = graphs[0]
468 assert g.is_homogeneous
469 assert g.num_nodes() == len(node_ids)
470 assert g.num_edges() == len(src_ids)
471 assert F.array_equal(
472 F.tensor(node_feat[idx], dtype=F.float32), g.ndata["feat"]
473 )
474
475 # node IDs are mixed with numeric and non-numeric values
476 # homogeneous graph
477 node_ids = [1, 2, 3, "a"]
478 src_ids = [1, 2, 3]
479 dst_ids = ["a", 1, 2]
480 node_data = NodeData(node_ids, {})
481 edge_data = EdgeData(src_ids, dst_ids, {})
482 graphs, data_dict = DGLGraphConstructor.construct_graphs(
483 node_data, edge_data
484 )
485 assert len(graphs) == 1
486 assert len(data_dict) == 0
487 g = graphs[0]

Callers 1

test_csvdataset — function · 0.85

Calls 6

NodeData — class · 0.90
EdgeData — class · 0.90
construct_graphs — method · 0.80
shuffle — method · 0.45
num_nodes — method · 0.45
num_edges — method · 0.45

Tested by

no test coverage detected