()
| 568 | |
| 569 | |
| 570 | def _test_construct_graphs_hetero(): |
| 571 | from dgl.data.csv_dataset_base import ( |
| 572 | DGLGraphConstructor, |
| 573 | EdgeData, |
| 574 | NodeData, |
| 575 | ) |
| 576 | |
| 577 | # node_id/src_id/dst_id could be non-sorted, duplicated, non-numeric. |
| 578 | num_nodes = 100 |
| 579 | num_edges = 1000 |
| 580 | num_dims = 3 |
| 581 | ntypes = ["user", "item"] |
| 582 | node_data = [] |
| 583 | node_ids_dict = {} |
| 584 | ndata_dict = {} |
| 585 | for ntype in ntypes: |
| 586 | node_ids = np.random.choice( |
| 587 | np.arange(num_nodes * 2), size=num_nodes, replace=False |
| 588 | ) |
| 589 | assert len(node_ids) == num_nodes |
| 590 | # to be non-sorted |
| 591 | np.random.shuffle(node_ids) |
| 592 | # to be non-numeric |
| 593 | node_ids = ["id_{}".format(id) for id in node_ids] |
| 594 | t_ndata = { |
| 595 | "feat": np.random.rand(num_nodes, num_dims), |
| 596 | "label": np.random.randint(2, size=num_nodes), |
| 597 | } |
| 598 | _, u_indices = np.unique(node_ids, return_index=True) |
| 599 | ndata = { |
| 600 | "feat": t_ndata["feat"][u_indices], |
| 601 | "label": t_ndata["label"][u_indices], |
| 602 | } |
| 603 | node_data.append(NodeData(node_ids, t_ndata, type=ntype)) |
| 604 | node_ids_dict[ntype] = node_ids |
| 605 | ndata_dict[ntype] = ndata |
| 606 | etypes = [("user", "follow", "user"), ("user", "like", "item")] |
| 607 | edge_data = [] |
| 608 | edata_dict = {} |
| 609 | for src_type, e_type, dst_type in etypes: |
| 610 | src_ids = np.random.choice(node_ids_dict[src_type], size=num_edges) |
| 611 | dst_ids = np.random.choice(node_ids_dict[dst_type], size=num_edges) |
| 612 | edata = { |
| 613 | "feat": np.random.rand(num_edges, num_dims), |
| 614 | "label": np.random.randint(2, size=num_edges), |
| 615 | } |
| 616 | edge_data.append( |
| 617 | EdgeData(src_ids, dst_ids, edata, type=(src_type, e_type, dst_type)) |
| 618 | ) |
| 619 | edata_dict[(src_type, e_type, dst_type)] = edata |
| 620 | graphs, data_dict = DGLGraphConstructor.construct_graphs( |
| 621 | node_data, edge_data |
| 622 | ) |
| 623 | assert len(graphs) == 1 |
| 624 | assert len(data_dict) == 0 |
| 625 | g = graphs[0] |
| 626 | assert not g.is_homogeneous |
| 627 | assert g.num_nodes() == num_nodes * len(ntypes) |
no test coverage detected