(agg, idtype)
| 794 | @parametrize_idtype |
| 795 | @pytest.mark.parametrize("agg", ["sum", "max", "min", "mean", "stack", myagg]) |
| 796 | def test_hetero_conv(agg, idtype): |
| 797 | g = dgl.heterograph( |
| 798 | { |
| 799 | ("user", "follows", "user"): ([0, 0, 2, 1], [1, 2, 1, 3]), |
| 800 | ("user", "plays", "game"): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]), |
| 801 | ("store", "sells", "game"): ([0, 0, 1, 1], [0, 3, 1, 2]), |
| 802 | }, |
| 803 | idtype=idtype, |
| 804 | device=F.ctx(), |
| 805 | ) |
| 806 | conv = nn.HeteroGraphConv( |
| 807 | { |
| 808 | "follows": nn.GraphConv(2, 3, allow_zero_in_degree=True), |
| 809 | "plays": nn.GraphConv(2, 4, allow_zero_in_degree=True), |
| 810 | "sells": nn.GraphConv(3, 4, allow_zero_in_degree=True), |
| 811 | }, |
| 812 | agg, |
| 813 | ) |
| 814 | conv.initialize(ctx=F.ctx()) |
| 815 | print(conv) |
| 816 | uf = F.randn((4, 2)) |
| 817 | gf = F.randn((4, 4)) |
| 818 | sf = F.randn((2, 3)) |
| 819 | |
| 820 | h = conv(g, {"user": uf, "store": sf, "game": gf}) |
| 821 | assert set(h.keys()) == {"user", "game"} |
| 822 | if agg != "stack": |
| 823 | assert h["user"].shape == (4, 3) |
| 824 | assert h["game"].shape == (4, 4) |
| 825 | else: |
| 826 | assert h["user"].shape == (4, 1, 3) |
| 827 | assert h["game"].shape == (4, 2, 4) |
| 828 | |
| 829 | block = dgl.to_block( |
| 830 | g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []} |
| 831 | ).to(F.ctx()) |
| 832 | h = conv( |
| 833 | block, |
| 834 | ( |
| 835 | {"user": uf, "game": gf, "store": sf}, |
| 836 | {"user": uf, "game": gf, "store": sf[0:0]}, |
| 837 | ), |
| 838 | ) |
| 839 | assert set(h.keys()) == {"user", "game"} |
| 840 | if agg != "stack": |
| 841 | assert h["user"].shape == (4, 3) |
| 842 | assert h["game"].shape == (4, 4) |
| 843 | else: |
| 844 | assert h["user"].shape == (4, 1, 3) |
| 845 | assert h["game"].shape == (4, 2, 4) |
| 846 | |
| 847 | h = conv(block, {"user": uf, "game": gf, "store": sf}) |
| 848 | assert set(h.keys()) == {"user", "game"} |
| 849 | if agg != "stack": |
| 850 | assert h["user"].shape == (4, 3) |
| 851 | assert h["game"].shape == (4, 4) |
| 852 | else: |
| 853 | assert h["user"].shape == (4, 1, 3) |
no test coverage detected