MCP · copy · Index your code
hub / github.com/dmlc/dgl / _test_nx_conversion

Function _test_nx_conversion

tests/python/common/function/test_basics.py:200–318  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

198
199
200def _test_nx_conversion():
201 # check conversion between networkx and DGLGraph
202
203 def _check_nx_feature(nxg, nf, ef):
204 # check node and edge feature of nxg
205 # this is used to check to_networkx
206 num_nodes = len(nxg)
207 num_edges = nxg.size()
208 if num_nodes > 0:
209 node_feat = ddict(list)
210 for nid, attr in nxg.nodes(data=True):
211 assert len(attr) == len(nf)
212 for k in nxg.nodes[nid]:
213 node_feat[k].append(F.unsqueeze(attr[k], 0))
214 for k in node_feat:
215 feat = F.cat(node_feat[k], 0)
216 assert F.allclose(feat, nf[k])
217 else:
218 assert len(nf) == 0
219 if num_edges > 0:
220 edge_feat = ddict(lambda: [0] * num_edges)
221 for u, v, attr in nxg.edges(data=True):
222 assert len(attr) == len(ef) + 1 # extra id
223 eid = attr["id"]
224 for k in ef:
225 edge_feat[k][eid] = F.unsqueeze(attr[k], 0)
226 for k in edge_feat:
227 feat = F.cat(edge_feat[k], 0)
228 assert F.allclose(feat, ef[k])
229 else:
230 assert len(ef) == 0
231
232 n1 = F.randn((5, 3))
233 n2 = F.randn((5, 10))
234 n3 = F.randn((5, 4))
235 e1 = F.randn((4, 5))
236 e2 = F.randn((4, 7))
237 g = dgl.graph(([0, 1, 3, 4], [2, 4, 0, 3]))
238 g.ndata.update({"n1": n1, "n2": n2, "n3": n3})
239 g.edata.update({"e1": e1, "e2": e2})
240
241 # convert to networkx
242 nxg = g.to_networkx(node_attrs=["n1", "n3"], edge_attrs=["e1", "e2"])
243 assert len(nxg) == 5
244 assert nxg.size() == 4
245 _check_nx_feature(nxg, {"n1": n1, "n3": n3}, {"e1": e1, "e2": e2})
246
247 # convert to DGLGraph, nx graph has id in edge feature
248 # use id feature to test non-tensor copy
249 g = dgl.from_networkx(nxg, node_attrs=["n1"], edge_attrs=["e1", "id"])
250 # check graph size
251 assert g.num_nodes() == 5
252 assert g.num_edges() == 4
253 # check number of features
254 # test with existing dglgraph (so existing features should be cleared)
255 assert len(g.ndata) == 1
256 assert len(g.edata) == 2
257 # check feature values

Callers

nothing calls this directly

Calls 15

_check_nx_feature — Function · 0.85
to_networkx — Method · 0.80
append — Method · 0.80
has_edge_between — Method · 0.80
graph — Method · 0.45
update — Method · 0.45
size — Method · 0.45
num_nodes — Method · 0.45
num_edges — Method · 0.45
astype — Method · 0.45
copy_to — Method · 0.45
cpu — Method · 0.45

Tested by

no test coverage detected