MCPcopy
hub / github.com/dmlc/dgl / test_hetero_conv

Function test_hetero_conv

tests/python/mxnet/test_nn.py:796–899  ·  view source on GitHub ↗
(agg, idtype)

Source from the content-addressed store, hash-verified

794@parametrize_idtype
795@pytest.mark.parametrize("agg", ["sum", "max", "min", "mean", "stack", myagg])
796def test_hetero_conv(agg, idtype):
797 g = dgl.heterograph(
798 {
799 ("user", "follows", "user"): ([0, 0, 2, 1], [1, 2, 1, 3]),
800 ("user", "plays", "game"): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
801 ("store", "sells", "game"): ([0, 0, 1, 1], [0, 3, 1, 2]),
802 },
803 idtype=idtype,
804 device=F.ctx(),
805 )
806 conv = nn.HeteroGraphConv(
807 {
808 "follows": nn.GraphConv(2, 3, allow_zero_in_degree=True),
809 "plays": nn.GraphConv(2, 4, allow_zero_in_degree=True),
810 "sells": nn.GraphConv(3, 4, allow_zero_in_degree=True),
811 },
812 agg,
813 )
814 conv.initialize(ctx=F.ctx())
815 print(conv)
816 uf = F.randn((4, 2))
817 gf = F.randn((4, 4))
818 sf = F.randn((2, 3))
819
820 h = conv(g, {"user": uf, "store": sf, "game": gf})
821 assert set(h.keys()) == {"user", "game"}
822 if agg != "stack":
823 assert h["user"].shape == (4, 3)
824 assert h["game"].shape == (4, 4)
825 else:
826 assert h["user"].shape == (4, 1, 3)
827 assert h["game"].shape == (4, 2, 4)
828
829 block = dgl.to_block(
830 g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
831 ).to(F.ctx())
832 h = conv(
833 block,
834 (
835 {"user": uf, "game": gf, "store": sf},
836 {"user": uf, "game": gf, "store": sf[0:0]},
837 ),
838 )
839 assert set(h.keys()) == {"user", "game"}
840 if agg != "stack":
841 assert h["user"].shape == (4, 3)
842 assert h["game"].shape == (4, 4)
843 else:
844 assert h["user"].shape == (4, 1, 3)
845 assert h["game"].shape == (4, 2, 4)
846
847 h = conv(block, {"user": uf, "game": gf, "store": sf})
848 assert set(h.keys()) == {"user", "game"}
849 if agg != "stack":
850 assert h["user"].shape == (4, 3)
851 assert h["game"].shape == (4, 4)
852 else:
853 assert h["user"].shape == (4, 1, 3)

Callers 1

test_nn.py (File) · 0.70

Calls 8

remove_edges (Method) · 0.80
MyMod (Class) · 0.70
ctx (Method) · 0.45
keys (Method) · 0.45
to (Method) · 0.45
cpu (Method) · 0.45
edges (Method) · 0.45
num_edges (Method) · 0.45

Tested by

no test coverage detected