MCPcopy
hub / github.com/dmlc/dgl / test_subgraph

Function test_subgraph

tests/python/common/test_heterograph.py:1609–1800  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

1607
1608@parametrize_idtype
1609def test_subgraph(idtype):
1610 g = create_test_heterograph(idtype)
1611 g_graph = g["follows"]
1612 g_bipartite = g["plays"]
1613
1614 x = F.randn((3, 5))
1615 y = F.randn((2, 4))
1616 g.nodes["user"].data["h"] = x
1617 g.edges["follows"].data["h"] = y
1618
1619 def _check_subgraph(g, sg):
1620 assert sg.idtype == g.idtype
1621 assert sg.device == g.device
1622 assert sg.ntypes == g.ntypes
1623 assert sg.etypes == g.etypes
1624 assert sg.canonical_etypes == g.canonical_etypes
1625 assert F.array_equal(
1626 F.tensor(sg.nodes["user"].data[dgl.NID]), F.tensor([1, 2], g.idtype)
1627 )
1628 assert F.array_equal(
1629 F.tensor(sg.nodes["game"].data[dgl.NID]), F.tensor([0], g.idtype)
1630 )
1631 assert F.array_equal(
1632 F.tensor(sg.edges["follows"].data[dgl.EID]), F.tensor([1], g.idtype)
1633 )
1634 assert F.array_equal(
1635 F.tensor(sg.edges["plays"].data[dgl.EID]), F.tensor([1], g.idtype)
1636 )
1637 assert F.array_equal(
1638 F.tensor(sg.edges["wishes"].data[dgl.EID]), F.tensor([1], g.idtype)
1639 )
1640 assert sg.num_nodes("developer") == 0
1641 assert sg.num_edges("develops") == 0
1642 assert F.array_equal(
1643 sg.nodes["user"].data["h"], g.nodes["user"].data["h"][1:3]
1644 )
1645 assert F.array_equal(
1646 sg.edges["follows"].data["h"], g.edges["follows"].data["h"][1:2]
1647 )
1648
1649 sg1 = g.subgraph({"user": [1, 2], "game": [0]})
1650 _check_subgraph(g, sg1)
1651 if F._default_context_str != "gpu":
1652 # TODO(minjie): enable this later
1653 sg2 = g.edge_subgraph({"follows": [1], "plays": [1], "wishes": [1]})
1654 _check_subgraph(g, sg2)
1655
1656 # backend tensor input
1657 sg1 = g.subgraph(
1658 {
1659 "user": F.tensor([1, 2], dtype=idtype),
1660 "game": F.tensor([0], dtype=idtype),
1661 }
1662 )
1663 _check_subgraph(g, sg1)
1664 if F._default_context_str != "gpu":
1665 # TODO(minjie): enable this later
1666 sg2 = g.edge_subgraph(

Callers

nothing calls this directly

Calls 7

create_test_heterograph · Function · 0.70
_check_subgraph · Function · 0.70
_check_typed_subgraph1 · Function · 0.70
_check_typed_subgraph2 · Function · 0.70
edge_subgraph · Method · 0.45

Tested by

no test coverage detected