MCPcopy Index your code
hub / github.com/dmlc/dgl / test_pull_0deg

Function test_pull_0deg

tests/python/common/function/test_basics.py:434–473  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

432
@parametrize_idtype
def test_pull_0deg(idtype):
    """Check ``pull`` on a graph that contains a zero-in-degree node.

    The graph has the single edge 0 -> 1, so node 0 never receives messages.
    Two scenarios are exercised:

    1. Pulling nodes 0 and 1 together: node 0 takes its ``x`` from the
       registered initializer (filled with 2) and is then doubled by the
       apply function. Node 1 gets ``(h[1] + h[0]) * 2``.
    2. Pulling node 0 alone with an apply function that doubles ``h``:
       ``h[0]`` is doubled and ``h[1]`` is left untouched.
    """
    graph = dgl.graph(([0], [1]), idtype=idtype, device=F.ctx())

    def copy_src_h(edges):
        # Each edge forwards its source node's "h" as message "m".
        return {"m": edges.src["h"]}

    def self_plus_mailbox(nodes):
        # New "x" = own "h" plus the incoming messages summed over axis 1.
        mailbox_total = F.sum(nodes.mailbox["m"], 1)
        return {"x": nodes.data["h"] + mailbox_total}

    def double_x(nodes):
        return {"x": nodes.data["x"] * 2}

    def double_h(nodes):
        return {"h": nodes.data["h"] * 2}

    def fill_with_two(shape, dtype, ctx, ids):
        # Initializer registered for the "x" field: a tensor of 2s.
        return 2 + F.zeros(shape, dtype, ctx)

    graph.set_n_initializer(fill_with_two, "x")

    # Scenario 1: pull a zero-in-degree node together with a non-zero one.
    feats = F.randn((2, 5))
    graph.ndata["h"] = feats
    graph.pull([0, 1], copy_src_h, self_plus_mailbox, double_x)
    pulled = graph.ndata["x"]
    # Node 0: initializer value 2, doubled by the apply function -> all 4s.
    expected_zero_deg = F.full_1d(5, 4, dtype=F.float32)
    assert F.allclose(pulled[0], expected_zero_deg)
    # Node 1: (h[1] + h[0]) * 2, i.e. the column-wise sum of both rows doubled.
    assert F.allclose(pulled[1], F.sum(feats, 0) * 2)

    # Scenario 2: pull only the zero-in-degree node.
    feats = F.randn((2, 5))
    graph.ndata["h"] = feats
    # The edge UDF sees no valid edges here and DGL warns ("The input graph
    # for the user-defined edge function does not contain valid edges").
    # Note that the filter below silences *every* UserWarning raised during
    # this pull, not just that one.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        graph.pull(0, copy_src_h, self_plus_mailbox, double_h)

    pulled = graph.ndata["h"]
    # Node 0: only the apply function takes effect, doubling its "h".
    assert F.allclose(pulled[0], 2 * feats[0])
    # Node 1 was not pulled, so its "h" must be unchanged.
    assert F.allclose(pulled[1], feats[1])
474
475
476def test_dynamic_addition():

Callers

nothing calls this directly

Calls 4

set_n_initializer — Method · 0.80
graph — Method · 0.45
ctx — Method · 0.45
pull — Method · 0.45

Tested by

no test coverage detected