()
| 644 | |
| 645 | |
| 646 | def _test_construct_graphs_multiple(): |
| 647 | from dgl.data.csv_dataset_base import ( |
| 648 | DGLGraphConstructor, |
| 649 | EdgeData, |
| 650 | GraphData, |
| 651 | NodeData, |
| 652 | ) |
| 653 | |
| 654 | num_nodes = 100 |
| 655 | num_edges = 1000 |
| 656 | num_graphs = 10 |
| 657 | num_dims = 3 |
| 658 | node_ids = np.array([], dtype=int) |
| 659 | src_ids = np.array([], dtype=int) |
| 660 | dst_ids = np.array([], dtype=int) |
| 661 | ngraph_ids = np.array([], dtype=int) |
| 662 | egraph_ids = np.array([], dtype=int) |
| 663 | u_indices = np.array([], dtype=int) |
| 664 | for i in range(num_graphs): |
| 665 | l_node_ids = np.random.choice( |
| 666 | np.arange(num_nodes * 2), size=num_nodes, replace=False |
| 667 | ) |
| 668 | node_ids = np.append(node_ids, l_node_ids) |
| 669 | _, l_u_indices = np.unique(l_node_ids, return_index=True) |
| 670 | u_indices = np.append(u_indices, l_u_indices) |
| 671 | ngraph_ids = np.append(ngraph_ids, np.full(num_nodes, i)) |
| 672 | src_ids = np.append( |
| 673 | src_ids, np.random.choice(l_node_ids, size=num_edges) |
| 674 | ) |
| 675 | dst_ids = np.append( |
| 676 | dst_ids, np.random.choice(l_node_ids, size=num_edges) |
| 677 | ) |
| 678 | egraph_ids = np.append(egraph_ids, np.full(num_edges, i)) |
| 679 | ndata = { |
| 680 | "feat": np.random.rand(num_nodes * num_graphs, num_dims), |
| 681 | "label": np.random.randint(2, size=num_nodes * num_graphs), |
| 682 | } |
| 683 | ngraph_ids = ["graph_{}".format(id) for id in ngraph_ids] |
| 684 | node_data = NodeData(node_ids, ndata, graph_id=ngraph_ids) |
| 685 | egraph_ids = ["graph_{}".format(id) for id in egraph_ids] |
| 686 | edata = { |
| 687 | "feat": np.random.rand(num_edges * num_graphs, num_dims), |
| 688 | "label": np.random.randint(2, size=num_edges * num_graphs), |
| 689 | } |
| 690 | edge_data = EdgeData(src_ids, dst_ids, edata, graph_id=egraph_ids) |
| 691 | gdata = { |
| 692 | "feat": np.random.rand(num_graphs, num_dims), |
| 693 | "label": np.random.randint(2, size=num_graphs), |
| 694 | } |
| 695 | graph_ids = ["graph_{}".format(id) for id in np.arange(num_graphs)] |
| 696 | graph_data = GraphData(graph_ids, gdata) |
| 697 | graphs, data_dict = DGLGraphConstructor.construct_graphs( |
| 698 | node_data, edge_data, graph_data |
| 699 | ) |
| 700 | assert len(graphs) == num_graphs |
| 701 | assert len(data_dict) == len(gdata) |
| 702 | for k, v in data_dict.items(): |
| 703 | assert F.dtype(v) != F.float64 |
no test coverage detected