MCPcopy Index your code
hub / github.com/dmlc/dgl / test_to_networkx

Function test_to_networkx

tests/python/common/test_convert.py:47–122  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

45)
46@parametrize_idtype
47def test_to_networkx(idtype):
48 # TODO: adapt and move code from the _test_nx_conversion function in
49 # tests/python/common/function/test_basics.py to here
50 # (pending resolution of https://github.com/dmlc/dgl/issues/5735).
51 g = dgl.heterograph(
52 {
53 ("user", "follows", "user"): ([0, 1], [1, 2]),
54 ("user", "follows", "topic"): ([1, 1], [1, 2]),
55 ("user", "plays", "game"): ([0, 3], [3, 4]),
56 },
57 idtype=idtype,
58 device=F.ctx(),
59 )
60
61 n1 = F.randn((5, 3))
62 n2 = F.randn((4, 2))
63 e1 = F.randn((2, 3))
64 e2 = F.randn((2, 2))
65
66 g.nodes["game"].data["n"] = F.copy_to(n1, ctx=F.ctx())
67 g.nodes["user"].data["n"] = F.copy_to(n2, ctx=F.ctx())
68 g.edges[("user", "follows", "user")].data["e"] = F.copy_to(e1, ctx=F.ctx())
69 g.edges["plays"].data["e"] = F.copy_to(e2, ctx=F.ctx())
70
71 nxg = dgl.to_networkx(
72 g,
73 node_attrs=["n"],
74 edge_attrs=["e"],
75 )
76
77 # Test nodes
78 nxg_nodes = dict(nxg.nodes(data=True))
79 assert len(nxg_nodes) == g.num_nodes()
80 assert {v["ntype"] for v in nxg_nodes.values()} == set(g.ntypes)
81
82 nxg_nodes_by_ntype = {}
83 for ntype in g.ntypes:
84 nxg_nodes_by_ntype[ntype] = get_nodes_by_ntype(nxg_nodes, ntype)
85 assert g.num_nodes(ntype) == len(nxg_nodes_by_ntype[ntype])
86
87 assert check_attrs_for_nodes(nxg_nodes_by_ntype["game"], {"ntype", "n"})
88 assert check_attr_values_for_nodes(nxg_nodes_by_ntype["game"], "n", n1)
89 assert check_attrs_for_nodes(nxg_nodes_by_ntype["user"], {"ntype", "n"})
90 assert check_attr_values_for_nodes(nxg_nodes_by_ntype["user"], "n", n2)
91 # Nodes without node attributes
92 assert check_attrs_for_nodes(nxg_nodes_by_ntype["topic"], {"ntype"})
93
94 # Test edges
95 nxg_edges = list(nxg.edges(data=True))
96 assert len(nxg_edges) == g.num_edges()
97 assert {edge_attrs(e)["etype"] for e in nxg_edges} == set(
98 g.canonical_etypes
99 )
100
101 nxg_edges_by_etype = {}
102 for etype in g.canonical_etypes:
103 nxg_edges_by_etype[etype] = get_edges_by_etype(nxg_edges, etype)
104 assert g.num_edges(etype) == len(nxg_edges_by_etype[etype])

Callers

nothing calls this directly

Calls 15

get_nodes_by_ntype — Function · 0.85
check_attrs_for_nodes — Function · 0.85
edge_attrs — Function · 0.85
get_edges_by_etype — Function · 0.85
check_attrs_for_edges — Function · 0.85
to_networkx — Method · 0.80
ctx — Method · 0.45
copy_to — Method · 0.45
nodes — Method · 0.45
num_nodes — Method · 0.45

Tested by

no test coverage detected