(g, idtype, out_dim)
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    """Check SGConv output shape in both uncached and cached modes.

    For ``cached=True`` the test expects the layer to reuse the result of its
    first call: feeding shifted features (``feat + 1``) on a second call must
    produce the same output as the first call.
    """
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    num_nodes = g.num_nodes()
    feat = F.randn((num_nodes, 5))

    # Not cached: one forward pass should yield one out_dim-wide row per node.
    sgc = nn.SGConv(5, out_dim, k=3)
    h = sgc(g, feat)
    assert h.shape[0] == num_nodes
    assert h.shape[-1] == out_dim

    # Cached: pass the flag by keyword so the intent is explicit. The second
    # call uses different input features, yet is expected to match the first
    # call because the cached computation is reused.
    sgc = nn.SGConv(5, out_dim, k=3, cached=True)
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[0] == num_nodes
    assert h_0.shape[-1] == out_dim
| 412 | |
| 413 | @parametrize_idtype |
no test coverage detected