()
| 510 | |
| 511 | |
def _test_construct_graphs_homo():
    """Build a homogeneous graph from string node IDs and verify its contents.

    The node IDs are unique, unordered and non-numeric. The test checks that
    ``DGLGraphConstructor.construct_graphs`` returns exactly one homogeneous
    graph and an empty data dict, that the graph has the expected number of
    nodes and edges, and that the node/edge data stored on the graph equal
    the inputs and are not stored as float64.
    """
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        NodeData,
    )

    def check_features(expected, actual):
        # Every expected entry must be present on the graph, must not be
        # stored as float64, and must match once the expected values are
        # cast to the dtype actually stored.
        for name in expected:
            assert name in actual
            stored = actual[name]
            stored_dtype = F.dtype(stored)
            assert stored_dtype != F.float64
            assert F.array_equal(
                F.tensor(expected[name], dtype=stored_dtype), stored
            )

    n_nodes = 100
    n_edges = 1000
    feat_dim = 3

    # Draw unique integer IDs from a range twice as large as needed, so the
    # IDs are not simply 0..n_nodes-1.
    raw_ids = np.random.choice(
        np.arange(n_nodes * 2), size=n_nodes, replace=False
    )
    assert len(raw_ids) == n_nodes
    # Make sure the IDs are not sorted ...
    np.random.shuffle(raw_ids)
    # ... and not numeric.
    node_ids = ["id_{}".format(raw_id) for raw_id in raw_ids]

    input_ndata = {
        "feat": np.random.rand(n_nodes, feat_dim),
        "label": np.random.randint(2, size=n_nodes),
    }
    # np.unique sorts the string IDs; the returned indices reorder the input
    # rows into that sorted order. Presumably this is the order in which the
    # constructor numbers the nodes -- confirm against DGLGraphConstructor.
    _, sorted_positions = np.unique(node_ids, return_index=True)
    expected_ndata = {
        name: input_ndata[name][sorted_positions]
        for name in ("feat", "label")
    }
    node_data = NodeData(node_ids, input_ndata)

    # Edge endpoints are sampled (with replacement) from the string IDs.
    src_ids = np.random.choice(node_ids, size=n_edges)
    dst_ids = np.random.choice(node_ids, size=n_edges)
    expected_edata = {
        "feat": np.random.rand(n_edges, feat_dim),
        "label": np.random.randint(2, size=n_edges),
    }
    edge_data = EdgeData(src_ids, dst_ids, expected_edata)

    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0

    graph = graphs[0]
    assert graph.is_homogeneous
    assert graph.num_nodes() == n_nodes
    assert graph.num_edges() == n_edges

    check_features(expected_ndata, graph.ndata)
    check_features(expected_edata, graph.edata)
| 568 | |
| 569 |
no test coverage detected