()
| 728 | |
| 729 | |
def test_sequential():
    """Exercise ``nn.Sequential`` with graph-aware gluon blocks.

    Two scenarios are checked:

    * a single graph reused by every layer, where each layer maps
      ``(graph, n_feat, e_feat)`` to ``(n_feat, e_feat)``;
    * a list of graphs ``[g1, g2, g3]`` passed to the container, where each
      layer halves the number of node-feature rows.

    Only output shapes are asserted.
    """
    ctx = F.ctx()

    # --- Scenario 1: one graph shared by all layers ----------------------
    class NodeEdgeLayer(gluon.nn.Block):
        """Add summed neighbour features to nodes and u+v sums to edges."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, graph, n_feat, e_feat):
            # local_var() so the "h"/"e" fields written here stay local
            # to this call.
            scratch = graph.local_var()
            scratch.ndata["h"] = n_feat
            scratch.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            n_feat += scratch.ndata["h"]
            # The edge term is built from the aggregated "h", not the raw input.
            scratch.apply_edges(fn.u_add_v("h", "h", "e"))
            e_feat += scratch.edata["e"]
            return n_feat, e_feat

    # Every ordered pair over 3 nodes, self-loops included: 9 edges.
    src_ids = [0, 1, 2] * 3
    dst_ids = [node for node in range(3) for _ in range(3)]
    graph = dgl.graph(([], [])).to(ctx)
    graph.add_nodes(3)
    graph.add_edges(src_ids, dst_ids)

    net = nn.Sequential()
    for _ in range(3):
        net.add(NodeEdgeLayer())
    net.initialize(ctx=ctx)
    node_in = F.randn((3, 4))
    edge_in = F.randn((9, 4))
    node_out, edge_out = net(graph, node_in, edge_in)
    assert node_out.shape == (3, 4)
    assert edge_out.shape == (9, 4)

    # --- Scenario 2: a list of graphs, node count halves per layer -------
    class PairPoolLayer(gluon.nn.Block):
        """Aggregate neighbour features, then sum consecutive row pairs."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, graph, n_feat):
            scratch = graph.local_var()
            scratch.ndata["h"] = n_feat
            scratch.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            n_feat += scratch.ndata["h"]
            # Assuming a 2-D n_feat of N rows: (N, D) -> (N // 2, 2, D),
            # then summing axis 1 gives (N // 2, D).
            half = scratch.num_nodes() // 2
            return n_feat.reshape(half, 2, -1).sum(1)

    # Built in the same order as before: 32, 16, then 8 nodes.
    graphs = [
        dgl.from_networkx(nx.erdos_renyi_graph(num_nodes, prob)).to(ctx)
        for num_nodes, prob in ((32, 0.05), (16, 0.2), (8, 0.8))
    ]

    net = nn.Sequential()
    for _ in range(3):
        net.add(PairPoolLayer())
    net.initialize(ctx=ctx)
    pooled = net(graphs, F.randn((32, 4)))
    # 32 -> 16 -> 8 -> 4 rows. This presumes layer i receives graphs[i];
    # the node counts 32/16/8 line up with that pairing.
    assert pooled.shape == (4, 4)
| 785 | |
| 786 | |
| 787 | def myagg(alist, dsttype): |
no test coverage detected