(out_dim)
| 298 | |
@pytest.mark.parametrize("out_dim", [1, 20])
def test_cheb_conv(out_dim):
    """Smoke-test ``nn.ChebConv`` on a random 20-node graph.

    Builds an Erdos-Renyi graph (20 nodes, edge probability 0.3) on the
    backend's test context, and constructs ``nn.ChebConv(10, out_dim, 3)``
    (k = 3). It then feeds random 10-dimensional node features through the
    layer and checks that the output has one row per node and ``out_dim``
    columns.
    """
    device = F.ctx()
    n_nodes, in_feats = 20, 10

    graph = dgl.from_networkx(nx.erdos_renyi_graph(n_nodes, 0.3)).to(device)

    layer = nn.ChebConv(in_feats, out_dim, 3)  # k = 3
    layer.initialize(ctx=device)
    print(layer)

    # test#1: basic -- output shape is (num_nodes, out_dim)
    feats = F.randn((n_nodes, in_feats))
    out = layer(graph, feats)
    assert out.shape == (n_nodes, out_dim)
| 312 | |
| 313 | |
| 314 | @parametrize_idtype |
no test coverage detected