
Function test_sgc_conv

tests/python/tensorflow/test_nn.py:395–410
(g, idtype, out_dim)


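```python
import pytest

import backend as F  # DGL test-suite backend shim (tests/backend)
import dgl.nn.tensorflow as nn
from test_utils import parametrize_idtype
from test_utils.graph_cases import get_cases


@parametrize_idtype  # supplies `idtype`
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)

    # Not cached: SGConv(in_feats=5, out_feats=out_dim, k=3)
    sgc = nn.SGConv(5, out_dim, 3)
    feat = F.randn((g.num_nodes(), 5))

    h = sgc(g, feat)
    assert h.shape[-1] == out_dim

    # Cached: the k-step propagated features are computed on the first call
    # and reused afterwards, so a different input (feat + 1) yields the same output.
    sgc = nn.SGConv(5, out_dim, 3, cached=True)
    h_0 = sgc(g, feat)
    h_1 = sgc(g, feat + 1)
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == out_dim
```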
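The test checks two properties of SGConv: the last output dimension equals `out_feats`, and with caching enabled the features propagated on the first call are reused, so a shifted input gives an identical result. The following is a minimal dense NumPy sketch of that behavior, not DGL's implementation. It assumes the propagation rule H = (D^{-1/2} A D^{-1/2})^k X W + b, where A is the adjacency matrix of the graph as given (A[i, j] = 1 for an edge j -> i) and D holds in-degrees; `DenseSGConvSketch` is a hypothetical name. It also illustrates why zero-in-degree graphs are a problem: D^{-1/2} is undefined for such nodes. DGL's SGConv raises an error on zero-in-degree nodes unless `allow_zero_in_degree=True`, which is likely why the test excludes the "zero-degree" cases.

```python
import numpy as np


class DenseSGConvSketch:
    """Hypothetical dense stand-in for an SGC layer (not DGL's code).

    Assumes H = (D^{-1/2} A D^{-1/2})^k X W + b, with A[i, j] = 1 for an
    edge j -> i and D the in-degree matrix. With cached=True the propagated
    features from the first call are stored and reused on later calls.
    """

    def __init__(self, in_feats, out_feats, k=1, cached=False, seed=0):
        rng = np.random.default_rng(seed)
        self.weight = rng.standard_normal((in_feats, out_feats))
        self.bias = np.zeros(out_feats)
        self.k = k
        self.cached = cached
        self._cached_h = None

    def __call__(self, adj, feat):
        if self.cached and self._cached_h is not None:
            # Reuse the stored propagation; `feat` is ignored on purpose.
            h = self._cached_h
        else:
            in_deg = adj.sum(axis=1)
            if np.any(in_deg == 0):
                # D^{-1/2} is undefined for nodes with no incoming edges.
                raise ValueError("graph has zero in-degree nodes")
            norm = in_deg ** -0.5
            h = feat
            for _ in range(self.k):
                # One step: h <- D^{-1/2} A D^{-1/2} h
                h = norm[:, None] * (adj @ (norm[:, None] * h))
            if self.cached:
                self._cached_h = h
        # Linear projection to out_feats columns.
        return h @ self.weight + self.bias


# A 4-node directed cycle 0->1->2->3->0: every node has in-degree 1.
adj = np.zeros((4, 4))
for src, dst in [(0, 1), (1, 2), (2, 3), (3, 0)]:
    adj[dst, src] = 1.0
feat = np.random.default_rng(1).standard_normal((4, 5))

layer = DenseSGConvSketch(5, 2, k=3)
assert layer(adj, feat).shape == (4, 2)

cached = DenseSGConvSketch(5, 2, k=3, cached=True)
h_0 = cached(adj, feat)
h_1 = cached(adj, feat + 1)
assert np.allclose(h_0, h_1)
```

The usage mirrors the test's parameters (5 input features, k=3). Caching only makes sense when the graph and input features are fixed across calls, as in transductive training; the sketch, like the test, shows that it deliberately ignores later inputs.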

Callers (1)

test_nn.py · File · 0.70

Calls (4)

ctx · Method · 0.45
to · Method · 0.45
astype · Method · 0.45
num_nodes · Method · 0.45
