Repository: github.com/dmlc/dgl

Function test_gin_conv

tests/python/tensorflow/test_nn.py:428–434
(g, idtype, aggregator_type)
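All three arguments are filled in by pytest parametrization. `g` and `aggregator_type` come from the two `@pytest.mark.parametrize` decorators in the source, and `idtype` presumably comes from DGL's `parametrize_idtype` decorator placed just above the lines shown. Stacking parametrize decorators makes pytest run the test once per combination of values. The sketch below is an illustration only, not DGL code. It enumerates that grid with `itertools.product`. The graph-case names and the idtype list are hypothetical placeholders, because the real values come from `get_cases` and the idtype decorator.

```python
from itertools import product
from typing import List, Sequence, Tuple

# Hypothetical placeholders: the real graph cases come from
# get_cases(["homo", "block-bipartite"]) and the real idtypes from the
# decorator that supplies `idtype`; only AGGREGATORS is copied from the test.
GRAPH_CASES = ["homo_case", "block_bipartite_case"]
IDTYPES = ["int32", "int64"]
AGGREGATORS = ["mean", "max", "sum"]


def parametrize_grid(
    graphs: Sequence[str], idtypes: Sequence[str], aggregators: Sequence[str]
) -> List[Tuple[str, str, str]]:
    """Every (g, idtype, aggregator_type) combination the stacked marks yield."""
    return list(product(graphs, idtypes, aggregators))


grid = parametrize_grid(GRAPH_CASES, IDTYPES, AGGREGATORS)
print(len(grid))  # 12, i.e. 2 graphs * 2 idtypes * 3 aggregators
print(grid[0])    # ('homo_case', 'int32', 'mean')
```

The test suite therefore grows multiplicatively with each decorator. The order in which pytest actually executes or names these cases may differ from `product`'s order. Only the set of combinations is the same.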


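```python
import pytest
import tensorflow as tf
import dgl.nn.tensorflow as nn

# Defined elsewhere in test_nn.py and not shown here: `F` (DGL's test backend
# shim) and `get_cases`. `idtype` is supplied by a decorator just above this
# excerpt, presumably DGL's `parametrize_idtype`.


@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv(g, idtype, aggregator_type):
    ctx = F.ctx()
    # Cast the graph to the parametrized index dtype and move it to the test device.
    g = g.astype(idtype).to(ctx)
    # GIN layer whose apply function is a Dense layer producing 12 output features.
    gin = nn.GINConv(tf.keras.layers.Dense(12), aggregator_type)
    # Random 5-dimensional input features, one row per source node.
    feat = F.randn((g.number_of_src_nodes(), 5))
    h = gin(g, feat)
    # Expect one 12-dimensional output row per destination node.
    assert h.shape == (g.number_of_dst_nodes(), 12)
```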
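The test checks only the output shape. It is worth seeing why the row count follows the destination nodes. `GINConv` documents its update as h_i' = f_Θ((1 + ε)·h_i + aggregate({h_j : j ∈ N(i)})). The following is a minimal pure-Python sketch of that update. It is not DGL's implementation, and it rests on several assumptions. Destination node v reuses source row v, which is how DGL blocks are commonly built, so a homogeneous graph is simply the case `num_dst == num_src`. A node with no in-edges receives a zero vector. ε is fixed, defaulting to 0. The names `aggregate`, `gin_forward`, `identity`, and `dense`, as well as the edge-list input format, are all hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def aggregate(messages: Sequence[Vector], kind: str, dim: int) -> Vector:
    """Element-wise reduce neighbor features with 'sum', 'mean' or 'max'."""
    if not messages:
        # Assumption: a node with no in-edges receives a zero vector.
        return [0.0] * dim
    columns = list(zip(*messages))
    if kind == "sum":
        return [sum(col) for col in columns]
    if kind == "mean":
        return [sum(col) / len(messages) for col in columns]
    if kind == "max":
        return [max(col) for col in columns]
    raise ValueError(f"unknown aggregator: {kind}")


def gin_forward(
    edges: Sequence[Tuple[int, int]],
    feat: Sequence[Vector],
    num_dst: int,
    kind: str,
    apply_func: Callable[[Vector], Vector],
    eps: float = 0.0,
) -> List[Vector]:
    """Hypothetical GIN layer over (src, dst) edges; returns one row per dst node."""
    dim = len(feat[0])
    inbox: List[List[Vector]] = [[] for _ in range(num_dst)]
    for src, dst in edges:
        inbox[dst].append(feat[src])  # collect features of in-neighbors
    out = []
    for v in range(num_dst):
        agg = aggregate(inbox[v], kind, dim)
        # (1 + eps) * own feature plus the aggregated neighbor features.
        combined = [(1.0 + eps) * x + a for x, a in zip(feat[v], agg)]
        out.append(apply_func(combined))
    return out


def identity(h: Vector) -> Vector:
    return h


# Block-style example: 4 source nodes, the first 2 of which are destinations.
feat = [[1.0], [2.0], [3.0], [4.0]]
edges = [(2, 0), (3, 0), (3, 1)]
for kind in ("sum", "mean", "max"):
    print(kind, gin_forward(edges, feat, 2, kind, identity))
# sum [[8.0], [6.0]]
# mean [[4.5], [6.0]]
# max [[5.0], [6.0]]

# Mirror the test's shape check with a hypothetical fixed 5 -> 12 linear map.
weights = [[0.1 * (i + j) for j in range(12)] for i in range(5)]


def dense(h: Vector) -> Vector:
    return [sum(h[i] * weights[i][j] for i in range(5)) for j in range(12)]


feat5 = [[float(n + k) for k in range(5)] for n in range(4)]
h = gin_forward(edges, feat5, 2, "mean", dense)
assert (len(h), len(h[0])) == (2, 12)  # (num_dst_nodes, 12), as in the test
```

Because the update is computed only for destination nodes, the output has `num_dst` rows, whatever the number of source rows. The apply function alone fixes the feature width, here 12. This is the same relationship the test asserts with `number_of_dst_nodes()`.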

Callers (1)

test_nn.py · File · 0.70

Calls (5)

number_of_src_nodes · Method · 0.80
number_of_dst_nodes · Method · 0.80
to · Method · 0.45
astype · Method · 0.45
ctx · Method · 0.45

Tested by

No test coverage detected.