MCPcopy
hub / github.com/dmlc/dgl / test_reduce_readout

Function test_reduce_readout

tests/python/common/test_readout.py:31–72  ·  view source on GitHub ↗
(g, idtype, reducer)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["dglgraph"]))
@pytest.mark.parametrize("reducer", ["sum", "max", "mean"])
def test_reduce_readout(g, idtype, reducer):
    """Batched readout must equal the concatenation of per-graph readouts.

    The graph is cast to ``idtype`` and moved to ``F.ctx()``. It is then
    given a random node feature ``"h"`` (width 3) and a random edge feature
    ``"h"`` (width 2). Four readout entry points are checked. Each one is
    applied to the whole graph and to every component returned by
    ``dgl.unbatch``; the two results must agree:

    * ``dgl.readout_nodes(..., op=reducer)`` and ``dgl.<reducer>_nodes``
    * ``dgl.readout_edges(..., op=reducer)`` and ``dgl.<reducer>_edges``
    """
    g = g.astype(idtype).to(F.ctx())
    g.ndata["h"] = F.randn((g.num_nodes(), 3))
    g.edata["h"] = F.randn((g.num_edges(), 2))

    # The readouts only read the "h" features, so the components are
    # computed once and reused by every check instead of re-unbatching
    # before each one.
    subgraphs = dgl.unbatch(g)

    nodes_shorthand = getattr(dgl, "{}_nodes".format(reducer))
    edges_shorthand = getattr(dgl, "{}_edges".format(reducer))

    # (label, readout) pairs; each readout takes a graph and returns the
    # per-graph reduced feature tensor.
    readouts = [
        # Test.1: node readout -- generic ``op=`` form, then the shorthand.
        (
            "readout_nodes",
            lambda graph: dgl.readout_nodes(graph, "h", op=reducer),
        ),
        (
            "{}_nodes".format(reducer),
            lambda graph: nodes_shorthand(graph, "h"),
        ),
        # Test.2: edge readout -- generic ``op=`` form, then the shorthand.
        (
            "readout_edges",
            lambda graph: dgl.readout_edges(graph, "h", op=reducer),
        ),
        (
            "{}_edges".format(reducer),
            lambda graph: edges_shorthand(graph, "h"),
        ),
    ]

    for name, readout in readouts:
        batched = readout(g)
        # check correctness: one row per component, in unbatch order
        expected = F.cat([readout(sg) for sg in subgraphs], dim=0)
        assert F.allclose(batched, expected), "{} mismatch (op={})".format(
            name, reducer
        )
73
74
75@parametrize_idtype

Callers

nothing calls this directly

Calls 7

append (method) · 0.80
format (method) · 0.80
to (method) · 0.45
astype (method) · 0.45
ctx (method) · 0.45
num_nodes (method) · 0.45
num_edges (method) · 0.45

Tested by

no test coverage detected