MCPcopy
hub / github.com/dmlc/dgl / test_pull_multi_fallback

Function test_pull_multi_fallback

tests/python/common/test_heterograph-specialization.py:238–303  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

236
237@parametrize_idtype
238def test_pull_multi_fallback(idtype):
239 # create a graph with zero in degree nodes
240 g = dgl.graph([])
241 g = g.astype(idtype).to(F.ctx())
242 g.add_nodes(10)
243 for i in range(1, 9):
244 g.add_edges(0, i)
245 g.add_edges(i, 9)
246 g.ndata["h"] = F.randn((10, D))
247 g.edata["w1"] = F.randn((16,))
248 g.edata["w2"] = F.randn((16, D))
249
250 def _mfunc_hxw1(edges):
251 return {"m1": edges.src["h"] * F.unsqueeze(edges.data["w1"], 1)}
252
253 def _mfunc_hxw2(edges):
254 return {"m2": edges.src["h"] * edges.data["w2"]}
255
256 def _rfunc_m1(nodes):
257 return {"o1": F.sum(nodes.mailbox["m1"], 1)}
258
259 def _rfunc_m2(nodes):
260 return {"o2": F.sum(nodes.mailbox["m2"], 1)}
261
262 def _rfunc_m1max(nodes):
263 return {"o3": F.max(nodes.mailbox["m1"], 1)}
264
265 def _afunc(nodes):
266 ret = {}
267 for k, v in nodes.data.items():
268 if k.startswith("o"):
269 ret[k] = 2 * v
270 return ret
271
272 # nodes to pull
273 def _pull_nodes(nodes):
274 # compute ground truth
275 g.pull(nodes, _mfunc_hxw1, _rfunc_m1, _afunc)
276 o1 = g.ndata.pop("o1")
277 g.pull(nodes, _mfunc_hxw2, _rfunc_m2, _afunc)
278 o2 = g.ndata.pop("o2")
279 g.pull(nodes, _mfunc_hxw1, _rfunc_m1max, _afunc)
280 o3 = g.ndata.pop("o3")
281 # v2v spmv
282 g.pull(
283 nodes,
284 fn.u_mul_e("h", "w1", "m1"),
285 fn.sum(msg="m1", out="o1"),
286 _afunc,
287 )
288 assert F.allclose(o1, g.ndata.pop("o1"))
289 # v2v fallback to e2v
290 g.pull(
291 nodes,
292 fn.u_mul_e("h", "w2", "m2"),
293 fn.sum(msg="m2", out="o2"),
294 _afunc,
295 )

Calls 7

_pull_nodes · Function · 0.85
graph · Method · 0.45
to · Method · 0.45
astype · Method · 0.45
ctx · Method · 0.45
add_nodes · Method · 0.45
add_edges · Method · 0.45

Tested by

no test coverage detected