| 1541 | |
| 1542 | |
| 1543 | def _test_NodeEdgeGraphData(): |
| 1544 | from dgl.data.csv_dataset_base import EdgeData, GraphData, NodeData |
| 1545 | |
| 1546 | # NodeData basics |
| 1547 | num_nodes = 100 |
| 1548 | node_ids = np.arange(num_nodes, dtype=float) |
| 1549 | ndata = NodeData(node_ids, {}) |
| 1550 | assert np.array_equal(ndata.id, node_ids) |
| 1551 | assert len(ndata.data) == 0 |
| 1552 | assert ndata.type == "_V" |
| 1553 | assert np.array_equal(ndata.graph_id, np.full(num_nodes, 0)) |
| 1554 | # NodeData more |
| 1555 | data = {"feat": np.random.rand(num_nodes, 3)} |
| 1556 | graph_id = np.arange(num_nodes) |
| 1557 | ndata = NodeData(node_ids, data, type="user", graph_id=graph_id) |
| 1558 | assert ndata.type == "user" |
| 1559 | assert np.array_equal(ndata.graph_id, graph_id) |
| 1560 | assert len(ndata.data) == len(data) |
| 1561 | for k, v in data.items(): |
| 1562 | assert k in ndata.data |
| 1563 | assert np.array_equal(ndata.data[k], v) |
| 1564 | # NodeData except |
| 1565 | expect_except = False |
| 1566 | try: |
| 1567 | NodeData( |
| 1568 | np.arange(num_nodes), |
| 1569 | {"feat": np.random.rand(num_nodes + 1, 3)}, |
| 1570 | graph_id=np.arange(num_nodes - 1), |
| 1571 | ) |
| 1572 | except: |
| 1573 | expect_except = True |
| 1574 | assert expect_except |
| 1575 | |
| 1576 | # EdgeData basics |
| 1577 | num_nodes = 100 |
| 1578 | num_edges = 1000 |
| 1579 | src_ids = np.random.randint(num_nodes, size=num_edges) |
| 1580 | dst_ids = np.random.randint(num_nodes, size=num_edges) |
| 1581 | edata = EdgeData(src_ids, dst_ids, {}) |
| 1582 | assert np.array_equal(edata.src, src_ids) |
| 1583 | assert np.array_equal(edata.dst, dst_ids) |
| 1584 | assert edata.type == ("_V", "_E", "_V") |
| 1585 | assert len(edata.data) == 0 |
| 1586 | assert np.array_equal(edata.graph_id, np.full(num_edges, 0)) |
| 1587 | # EdageData more |
| 1588 | src_ids = np.random.randint(num_nodes, size=num_edges).astype(float) |
| 1589 | dst_ids = np.random.randint(num_nodes, size=num_edges).astype(float) |
| 1590 | data = {"feat": np.random.rand(num_edges, 3)} |
| 1591 | etype = ("user", "like", "item") |
| 1592 | graph_ids = np.arange(num_edges) |
| 1593 | edata = EdgeData(src_ids, dst_ids, data, type=etype, graph_id=graph_ids) |
| 1594 | assert np.array_equal(edata.src, src_ids) |
| 1595 | assert np.array_equal(edata.dst, dst_ids) |
| 1596 | assert edata.type == etype |
| 1597 | assert len(edata.data) == len(data) |
| 1598 | for k, v in data.items(): |
| 1599 | assert k in edata.data |
| 1600 | assert np.array_equal(edata.data[k], v) |