MCP · copy · Index your code
hub / github.com/dmlc/dgl / test_local_var

Function test_local_var

tests/python/common/function/test_basics.py:521–590  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

519
520@parametrize_idtype
521def test_local_var(idtype):
522 g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]), idtype=idtype, device=F.ctx())
523 g.ndata["h"] = F.zeros((g.num_nodes(), 3))
524 g.edata["w"] = F.zeros((g.num_edges(), 4))
525
526 # test override
527 def foo(g):
528 g = g.local_var()
529 g.ndata["h"] = F.ones((g.num_nodes(), 3))
530 g.edata["w"] = F.ones((g.num_edges(), 4))
531
532 foo(g)
533 assert F.allclose(g.ndata["h"], F.zeros((g.num_nodes(), 3)))
534 assert F.allclose(g.edata["w"], F.zeros((g.num_edges(), 4)))
535
536 # test out-place update
537 def foo(g):
538 g = g.local_var()
539 g.nodes[[2, 3]].data["h"] = F.ones((2, 3))
540 g.edges[[2, 3]].data["w"] = F.ones((2, 4))
541
542 foo(g)
543 assert F.allclose(g.ndata["h"], F.zeros((g.num_nodes(), 3)))
544 assert F.allclose(g.edata["w"], F.zeros((g.num_edges(), 4)))
545
546 # test out-place update 2
547 def foo(g):
548 g = g.local_var()
549 g.apply_nodes(lambda nodes: {"h": nodes.data["h"] + 10}, [2, 3])
550 g.apply_edges(lambda edges: {"w": edges.data["w"] + 10}, [2, 3])
551
552 foo(g)
553 assert F.allclose(g.ndata["h"], F.zeros((g.num_nodes(), 3)))
554 assert F.allclose(g.edata["w"], F.zeros((g.num_edges(), 4)))
555
556 # test auto-pop
557 def foo(g):
558 g = g.local_var()
559 g.ndata["hh"] = F.ones((g.num_nodes(), 3))
560 g.edata["ww"] = F.ones((g.num_edges(), 4))
561
562 foo(g)
563 assert "hh" not in g.ndata
564 assert "ww" not in g.edata
565
566 # test initializer1
567 g = dgl.graph(([0, 1], [1, 1]), idtype=idtype, device=F.ctx())
568 g.set_n_initializer(dgl.init.zero_initializer)
569
570 def foo(g):
571 g = g.local_var()
572 g.nodes[0].data["h"] = F.ones((1, 1))
573 assert F.allclose(g.ndata["h"], F.tensor([[1.0], [0.0]]))
574
575 foo(g)
576
577 # test initializer2
578 def foo_e_initializer(shape, dtype, ctx, id_range):

Callers

nothing calls this directly

Calls 7

set_n_initializer — Method · 0.80
set_e_initializer — Method · 0.80
foo — Function · 0.70
graph — Method · 0.45
ctx — Method · 0.45
num_nodes — Method · 0.45
num_edges — Method · 0.45

Tested by

no test coverage detected