MCPcopy
hub / github.com/dmlc/dgl / _pull_nodes

Function _pull_nodes

tests/python/common/test_heterograph-specialization.py:273–296  ·  view source on GitHub ↗
(nodes)

Source from the content-addressed store, hash-verified

271
272 # nodes to pull
def _pull_nodes(nodes):
    """Check that builtin-function ``g.pull`` matches UDF ``g.pull`` on ``nodes``.

    Reference outputs are computed first with the user-defined message /
    reduce functions. The same pulls are then repeated with DGL builtin
    functions, and each result is compared to its reference with
    ``F.allclose``. Every output field (``o1``/``o2``/``o3``) is popped from
    ``g.ndata`` right after it is read, so no comparison sees a field left
    over from an earlier pull.

    ``g``, ``F``, ``fn`` and the ``_mfunc_*`` / ``_rfunc_*`` / ``_afunc`` UDFs
    are free variables here. They are presumably bound in the enclosing test
    function (TODO confirm against the caller).

    Args:
        nodes: node IDs to pull into; passed unchanged to every ``g.pull`` call.
    """
    # Ground truth from the user-defined message/reduce functions.
    g.pull(nodes, _mfunc_hxw1, _rfunc_m1, _afunc)
    o1 = g.ndata.pop("o1")
    g.pull(nodes, _mfunc_hxw2, _rfunc_m2, _afunc)
    o2 = g.ndata.pop("o2")
    g.pull(nodes, _mfunc_hxw1, _rfunc_m1max, _afunc)
    o3 = g.ndata.pop("o3")

    # v2v spmv
    g.pull(
        nodes,
        fn.u_mul_e("h", "w1", "m1"),
        fn.sum(msg="m1", out="o1"),
        _afunc,
    )
    assert F.allclose(o1, g.ndata.pop("o1"))

    # v2v fallback to e2v
    g.pull(
        nodes,
        fn.u_mul_e("h", "w2", "m2"),
        fn.sum(msg="m2", out="o2"),
        _afunc,
    )
    assert F.allclose(o2, g.ndata.pop("o2"))

    # Max reducer. Compare against the `_rfunc_m1max` reference `o3`. Without
    # this check, `o3` is computed and then discarded, and the builtin
    # `fn.max` path of `pull` goes unverified.
    g.pull(
        nodes,
        fn.u_mul_e("h", "w1", "m1"),
        fn.max(msg="m1", out="o3"),
        _afunc,
    )
    assert F.allclose(o3, g.ndata.pop("o3"))
297
298 # test#1: non-0deg nodes
299 nodes = [1, 2, 9]

Callers 1

test_pull_multi_fallback · Function · 0.85

Calls 1

pull · Method · 0.45

Tested by

no test coverage detected