MCPcopy
hub / github.com/dmlc/dgl / _test_NodeEdgeGraphData

Function _test_NodeEdgeGraphData

tests/python/common/data/test_data.py:1543–1629  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

1541
1542
1543def _test_NodeEdgeGraphData():
1544 from dgl.data.csv_dataset_base import EdgeData, GraphData, NodeData
1545
1546 # NodeData basics
1547 num_nodes = 100
1548 node_ids = np.arange(num_nodes, dtype=float)
1549 ndata = NodeData(node_ids, {})
1550 assert np.array_equal(ndata.id, node_ids)
1551 assert len(ndata.data) == 0
1552 assert ndata.type == "_V"
1553 assert np.array_equal(ndata.graph_id, np.full(num_nodes, 0))
1554 # NodeData more
1555 data = {"feat": np.random.rand(num_nodes, 3)}
1556 graph_id = np.arange(num_nodes)
1557 ndata = NodeData(node_ids, data, type="user", graph_id=graph_id)
1558 assert ndata.type == "user"
1559 assert np.array_equal(ndata.graph_id, graph_id)
1560 assert len(ndata.data) == len(data)
1561 for k, v in data.items():
1562 assert k in ndata.data
1563 assert np.array_equal(ndata.data[k], v)
1564 # NodeData except
1565 expect_except = False
1566 try:
1567 NodeData(
1568 np.arange(num_nodes),
1569 {"feat": np.random.rand(num_nodes + 1, 3)},
1570 graph_id=np.arange(num_nodes - 1),
1571 )
1572 except:
1573 expect_except = True
1574 assert expect_except
1575
1576 # EdgeData basics
1577 num_nodes = 100
1578 num_edges = 1000
1579 src_ids = np.random.randint(num_nodes, size=num_edges)
1580 dst_ids = np.random.randint(num_nodes, size=num_edges)
1581 edata = EdgeData(src_ids, dst_ids, {})
1582 assert np.array_equal(edata.src, src_ids)
1583 assert np.array_equal(edata.dst, dst_ids)
1584 assert edata.type == ("_V", "_E", "_V")
1585 assert len(edata.data) == 0
1586 assert np.array_equal(edata.graph_id, np.full(num_edges, 0))
1587 # EdageData more
1588 src_ids = np.random.randint(num_nodes, size=num_edges).astype(float)
1589 dst_ids = np.random.randint(num_nodes, size=num_edges).astype(float)
1590 data = {"feat": np.random.rand(num_edges, 3)}
1591 etype = ("user", "like", "item")
1592 graph_ids = np.arange(num_edges)
1593 edata = EdgeData(src_ids, dst_ids, data, type=etype, graph_id=graph_ids)
1594 assert np.array_equal(edata.src, src_ids)
1595 assert np.array_equal(edata.dst, dst_ids)
1596 assert edata.type == etype
1597 assert len(edata.data) == len(data)
1598 for k, v in data.items():
1599 assert k in edata.data
1600 assert np.array_equal(edata.data[k], v)

Callers 1

test_csvdataset (Function) · 0.85

Calls 5

NodeData (Class) · 0.90
EdgeData (Class) · 0.90
GraphData (Class) · 0.90
items (Method) · 0.45
astype (Method) · 0.45

Tested by

no test coverage detected