MCPcopy
hub / github.com/dmlc/dgl / test_weighted_reduce_readout

Function test_weighted_reduce_readout

tests/python/common/test_readout.py:78–121  ·  view source on GitHub ↗
(g, idtype, reducer)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["dglgraph"]))
@pytest.mark.parametrize("reducer", ["sum", "max", "mean"])
def test_weighted_reduce_readout(g, idtype, reducer):
    """Check weighted node/edge readouts on a batched graph.

    Random node features/weights (``"h"``/``"w"``) and edge
    features/weights are attached to ``g``.

    For each readout, the result computed on the whole batched graph must
    equal the row-wise concatenation (``F.cat(..., dim=0)``) of the results
    computed on every component graph returned by ``dgl.unbatch``.

    Both entry points are exercised:

    - the generic ``dgl.readout_nodes`` / ``dgl.readout_edges`` with
      ``op=reducer``;
    - the named shortcuts ``dgl.<reducer>_nodes`` / ``dgl.<reducer>_edges``.

    NOTE(review): ``idtype`` is presumably injected by a
    ``parametrize_idtype`` decorator/fixture outside this view -- confirm.

    NOTE(review): this only checks batched-vs-per-graph consistency; it does
    not compare against an independently computed weighted reduction.
    """
    g = g.astype(idtype).to(F.ctx())
    g.ndata["h"] = F.randn((g.num_nodes(), 3))
    g.ndata["w"] = F.randn((g.num_nodes(), 1))
    g.edata["h"] = F.randn((g.num_edges(), 2))
    g.edata["w"] = F.randn((g.num_edges(), 1))

    # Unbatch once and reuse the component graphs for every comparison
    # (assumes the readout functions do not mutate their input graph).
    subgraphs = dgl.unbatch(g)

    def check_readout(readout):
        """Assert ``readout(g)`` equals the stacked per-subgraph readouts."""
        batched = readout(g)
        per_graph = [readout(sg) for sg in subgraphs]
        assert F.allclose(batched, F.cat(per_graph, dim=0))

    # Test.1: node readout, via the generic API and the named shortcut.
    check_readout(lambda graph: dgl.readout_nodes(graph, "h", "w", op=reducer))
    node_shortcut = getattr(dgl, "{}_nodes".format(reducer))
    check_readout(lambda graph: node_shortcut(graph, "h", "w"))

    # Test.2: edge readout, via the generic API and the named shortcut.
    check_readout(lambda graph: dgl.readout_edges(graph, "h", "w", op=reducer))
    edge_shortcut = getattr(dgl, "{}_edges".format(reducer))
    check_readout(lambda graph: edge_shortcut(graph, "h", "w"))
122
123
124@parametrize_idtype

Callers

nothing calls this directly

Calls 7

append (method) · 0.80
format (method) · 0.80
to (method) · 0.45
astype (method) · 0.45
ctx (method) · 0.45
num_nodes (method) · 0.45
num_edges (method) · 0.45

Tested by

no test coverage detected