(g, idtype, aggregator_type)
| 426 | @pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"])) |
| 427 | @pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"]) |
def test_gin_conv(g, idtype, aggregator_type):
    """Check the output shape of ``nn.GINConv`` wrapping a 12-unit Dense layer.

    The graph is cast to ``idtype`` and moved to the backend context, random
    features with 5 columns per source node are pushed through the layer, and
    the result must have one 12-wide row per destination node.

    Args:
        g: Graph case supplied by the ``get_cases`` parametrization
            ("homo" and "block-bipartite" cases).
        idtype: Index dtype the graph is cast to via ``g.astype``.
            NOTE(review): presumably injected by a parametrizing decorator
            above this function -- confirm.
        aggregator_type: Aggregator name passed to ``GINConv``; parametrized
            over "mean", "max" and "sum".
    """
    # F.ctx() is only needed here; the former unused ``ctx`` local was dropped.
    g = g.astype(idtype).to(F.ctx())
    gin = nn.GINConv(tf.keras.layers.Dense(12), aggregator_type)
    # One feature row per *source* node: for block/bipartite cases the source
    # and destination node counts are queried separately.
    feat = F.randn((g.number_of_src_nodes(), 5))
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)
| 435 | |
| 436 | |
| 437 | @parametrize_idtype |
no test coverage detected