(idtype, g, aggre_type, out_dim)
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn"])
@pytest.mark.parametrize("out_dim", [1, 10])
def test_sage_conv(idtype, g, aggre_type, out_dim):
    """Run a forward pass of ``nn.SAGEConv`` and check the output width.

    The graph is cast to ``idtype`` and moved to the backend context, a
    layer mapping 5 input features to ``out_dim`` is built with the given
    ``aggre_type``, and the last dimension of the result is asserted to
    equal ``out_dim``.
    """
    in_dim = 5  # input feature width shared by the layer and the features
    device = F.ctx()
    graph = g.astype(idtype).to(device)

    layer = nn.SAGEConv(in_dim, out_dim, aggre_type)
    # One feature row per source node of the graph.
    inputs = F.randn((graph.number_of_src_nodes(), in_dim))
    layer.initialize(ctx=device)

    out = layer(graph, inputs)
    assert out.shape[-1] == out_dim
| 240 | |
| 241 | |
| 242 | @parametrize_idtype |
no test coverage detected