MCPcopy
hub / github.com/dmlc/dgl / _test_construct_graphs_multiple

Function _test_construct_graphs_multiple

tests/python/common/data/test_data.py:646–737  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

644
645
def _test_construct_graphs_multiple():
    """Build several random graphs in one ``construct_graphs`` call and check the result.

    Node, edge and graph records for ``num_graphs`` graphs are generated with
    NumPy, tagged with string graph IDs of the form ``graph_<i>``, wrapped in
    ``NodeData`` / ``EdgeData`` / ``GraphData`` and handed to
    ``DGLGraphConstructor.construct_graphs``. The visible checks are that one
    graph is returned per graph ID, that the returned data dict has one entry
    per graph-level field, and that none of those entries is ``float64``.
    """
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        GraphData,
        NodeData,
    )

    num_nodes = 100
    num_edges = 1000
    num_graphs = 10
    num_dims = 3

    # Per-graph pieces are collected here and joined once after the loop.
    node_id_chunks = []
    src_chunks = []
    dst_chunks = []
    node_owner_chunks = []
    edge_owner_chunks = []
    u_index_chunks = []
    for graph_idx in range(num_graphs):
        # Unique node IDs drawn from a range twice as large as needed, so the
        # IDs are neither contiguous nor sorted.
        l_node_ids = np.random.choice(
            np.arange(num_nodes * 2), size=num_nodes, replace=False
        )
        node_id_chunks.append(l_node_ids)
        # Positions of each ID's first occurrence, listed in sorted-ID order.
        l_u_indices = np.unique(l_node_ids, return_index=True)[1]
        u_index_chunks.append(l_u_indices)
        node_owner_chunks.append(np.full(num_nodes, graph_idx))
        # Edge endpoints are sampled (with replacement) from this graph's IDs;
        # sources are drawn before destinations.
        src_chunks.append(np.random.choice(l_node_ids, size=num_edges))
        dst_chunks.append(np.random.choice(l_node_ids, size=num_edges))
        edge_owner_chunks.append(np.full(num_edges, graph_idx))

    node_ids = np.concatenate(node_id_chunks)
    src_ids = np.concatenate(src_chunks)
    dst_ids = np.concatenate(dst_chunks)
    u_indices = np.concatenate(u_index_chunks)

    total_nodes = num_nodes * num_graphs
    total_edges = num_edges * num_graphs
    make_graph_name = "graph_{}".format

    # Node table: float features plus a binary label, tagged by graph name.
    ndata = {
        "feat": np.random.rand(total_nodes, num_dims),
        "label": np.random.randint(2, size=total_nodes),
    }
    ngraph_ids = list(map(make_graph_name, np.concatenate(node_owner_chunks)))
    node_data = NodeData(node_ids, ndata, graph_id=ngraph_ids)

    # Edge table, built the same way.
    egraph_ids = list(map(make_graph_name, np.concatenate(edge_owner_chunks)))
    edata = {
        "feat": np.random.rand(total_edges, num_dims),
        "label": np.random.randint(2, size=total_edges),
    }
    edge_data = EdgeData(src_ids, dst_ids, edata, graph_id=egraph_ids)

    # Graph table: one feature row and one label per graph.
    gdata = {
        "feat": np.random.rand(num_graphs, num_dims),
        "label": np.random.randint(2, size=num_graphs),
    }
    graph_ids = list(map(make_graph_name, np.arange(num_graphs)))
    graph_data = GraphData(graph_ids, gdata)

    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data, graph_data
    )
    assert len(graphs) == num_graphs
    assert len(data_dict) == len(gdata)
    for _name, value in data_dict.items():
        assert F.dtype(value) != F.float64
    # NOTE(review): the listing this was taken from is labelled as spanning
    # lines 646-737 but ends at this point; any later per-graph checks are not
    # shown here -- confirm against the full file.

Callers 1

test_csvdataset — Function · 0.85

Calls 11

NodeData — Class · 0.90
EdgeData — Class · 0.90
GraphData — Class · 0.90
assert_data — Function · 0.85
append — Method · 0.80
format — Method · 0.80
construct_graphs — Method · 0.80
items — Method · 0.45
dtype — Method · 0.45
num_nodes — Method · 0.45
num_edges — Method · 0.45

Tested by

no test coverage detected