MCPcopy
hub / github.com/dmlc/dgl / test_sequential

Function test_sequential

tests/python/mxnet/test_nn.py:730–784  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

728
729
def test_sequential():
    """Check that ``nn.Sequential`` chains DGL-aware Gluon blocks.

    Two scenarios are exercised:

    1. Every layer receives the same graph plus node and edge features and
       returns updated ``(n_feat, e_feat)``; output shapes must equal the
       input shapes.
    2. The network is called with a list of graphs ``[g1, g2, g3]``. Each
       layer halves the number of feature rows, so the row count goes
       32 -> 16 -> 8 -> 4. That lines up with the node counts of g1, g2
       and g3 (32, 16, 8) at the input of each successive layer.
    """
    ctx = F.ctx()

    # ---- Scenario 1: one graph shared by every layer ----------------------
    class SingleGraphLayer(gluon.nn.Block):
        """Message-passing layer returning updated node and edge features."""

        def forward(self, graph, n_feat, e_feat):
            # Work on a local view so the "h"/"e" fields set here do not
            # leak into the caller's graph.
            graph = graph.local_var()
            graph.ndata["h"] = n_feat
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            # Out-of-place adds: do not mutate the caller's input arrays.
            n_feat = n_feat + graph.ndata["h"]
            graph.apply_edges(fn.u_add_v("h", "h", "e"))
            e_feat = e_feat + graph.edata["e"]
            return n_feat, e_feat

    g = dgl.graph(([], [])).to(ctx)
    g.add_nodes(3)
    # All 9 ordered (src, dst) pairs on 3 nodes, self-loops included.
    g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])

    net = nn.Sequential()
    for _ in range(3):
        net.add(SingleGraphLayer())
    net.initialize(ctx=ctx)

    n_feat = F.randn((3, 4))
    e_feat = F.randn((9, 4))
    n_feat, e_feat = net(g, n_feat, e_feat)
    assert n_feat.shape == (3, 4)
    assert e_feat.shape == (9, 4)

    # ---- Scenario 2: one graph per layer ----------------------------------
    class PairPoolLayer(gluon.nn.Block):
        """Message passing followed by summing consecutive pairs of rows."""

        def forward(self, graph, n_feat):
            graph = graph.local_var()
            graph.ndata["h"] = n_feat
            graph.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            n_feat = n_feat + graph.ndata["h"]
            # (N, D) -> (N // 2, 2, D) -> sum over axis 1 -> (N // 2, D).
            return n_feat.reshape(graph.num_nodes() // 2, 2, -1).sum(1)

    g1 = dgl.from_networkx(nx.erdos_renyi_graph(32, 0.05)).to(ctx)
    g2 = dgl.from_networkx(nx.erdos_renyi_graph(16, 0.2)).to(ctx)
    g3 = dgl.from_networkx(nx.erdos_renyi_graph(8, 0.8)).to(ctx)

    net = nn.Sequential()
    for _ in range(3):
        net.add(PairPoolLayer())
    net.initialize(ctx=ctx)

    n_feat = F.randn((32, 4))
    n_feat = net([g1, g2, g3], n_feat)
    assert n_feat.shape == (4, 4)
785
786
787def myagg(alist, dsttype):

Callers 1

test_nn.py · File · 0.70

Calls 6

ExampleLayer · Class · 0.70
ctx · Method · 0.45
to · Method · 0.45
graph · Method · 0.45
add_nodes · Method · 0.45
add_edges · Method · 0.45

Tested by

no test coverage detected