github.com/dmlc/dgl

Function test_softmax

tests/python/common/test_readout.py:193–212
(g, idtype)


```python
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["dglgraph"]))
def test_softmax(g, idtype):
    g = g.astype(idtype).to(F.ctx())
    g.ndata["h"] = F.randn((g.num_nodes(), 3))
    g.edata["h"] = F.randn((g.num_edges(), 2))

    # Test.1: node readout
    x = dgl.softmax_nodes(g, "h")
    subg = dgl.unbatch(g)
    subx = []
    for sg in subg:
        subx.append(F.softmax(sg.ndata["h"], dim=0))
    assert F.allclose(x, F.cat(subx, dim=0))

    # Test.2: edge readout
    x = dgl.softmax_edges(g, "h")
    subg = dgl.unbatch(g)
    subx = []
    for sg in subg:
        subx.append(F.softmax(sg.edata["h"], dim=0))
    assert F.allclose(x, F.cat(subx, dim=0))
```

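The test checks that the batched readouts `dgl.softmax_nodes` and `dgl.softmax_edges` agree with a per-graph oracle: unbatch, apply `F.softmax(..., dim=0)` to each subgraph's features (so each feature column is normalized over that graph's nodes or edges), and concatenate. A minimal pure-Python sketch of that comparison follows, assuming one common way to compute a segmented softmax (per-segment max, shifted exponentials, per-segment sum). It is not DGL's implementation; `segment_softmax`, `reference_softmax`, `softmax_dim0`, and `allclose` are hypothetical helpers, and `seg_sizes` stands in for the per-graph node (or edge) counts of a batched graph.

```python
# Hypothetical helpers for illustration only; not DGL's implementation.
import math
import random
from typing import List

Matrix = List[List[float]]  # rows = nodes (or edges), columns = feature dims


def softmax_dim0(rows: Matrix) -> Matrix:
    """Softmax over dim=0: each feature column is normalized across all rows."""
    if not rows:
        return []
    ncols = len(rows[0])
    out = [[0.0] * ncols for _ in rows]
    for j in range(ncols):
        col_max = max(r[j] for r in rows)  # shift by the max for numerical stability
        exps = [math.exp(r[j] - col_max) for r in rows]
        total = sum(exps)
        for i, e in enumerate(exps):
            out[i][j] = e / total
    return out


def reference_softmax(x: Matrix, seg_sizes: List[int]) -> Matrix:
    """Mirror of the test's oracle: split per graph, softmax each piece, concatenate."""
    out: Matrix = []
    start = 0
    for size in seg_sizes:
        out.extend(softmax_dim0(x[start:start + size]))
        start += size
    return out


def segment_softmax(x: Matrix, seg_sizes: List[int]) -> Matrix:
    """Works on all rows at once using a per-row segment id, without splitting."""
    # Map each row to its segment (graph), e.g. sizes [2, 3] -> ids [0, 0, 1, 1, 1].
    seg_ids = [s for s, size in enumerate(seg_sizes) for _ in range(size)]
    if len(seg_ids) != len(x):
        raise ValueError("segment sizes must sum to the number of rows")
    ncols = len(x[0]) if x else 0
    nseg = len(seg_sizes)
    # Per-segment, per-column maximum (a "scatter max").
    seg_max = [[-math.inf] * ncols for _ in range(nseg)]
    for row, s in zip(x, seg_ids):
        for j, v in enumerate(row):
            seg_max[s][j] = max(seg_max[s][j], v)
    # Shifted exponentials and per-segment sums (a "scatter add").
    exps = [[math.exp(v - seg_max[s][j]) for j, v in enumerate(row)]
            for row, s in zip(x, seg_ids)]
    seg_sum = [[0.0] * ncols for _ in range(nseg)]
    for row, s in zip(exps, seg_ids):
        for j, e in enumerate(row):
            seg_sum[s][j] += e
    # Normalize each entry by the sum of its own segment and column.
    return [[e / seg_sum[s][j] for j, e in enumerate(row)]
            for row, s in zip(exps, seg_ids)]


def allclose(a: Matrix, b: Matrix, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Element-wise closeness check in the spirit of numpy.allclose."""
    return len(a) == len(b) and all(
        len(ra) == len(rb)
        and all(abs(u - v) <= atol + rtol * abs(v) for u, v in zip(ra, rb))
        for ra, rb in zip(a, b)
    )


if __name__ == "__main__":
    random.seed(0)
    sizes = [2, 0, 3]  # three "graphs" in the batch; the middle one is empty
    x = [[random.gauss(0.0, 1.0) for _ in range(3)] for _ in range(sum(sizes))]
    assert allclose(segment_softmax(x, sizes), reference_softmax(x, sizes))
    print("segment softmax matches the per-graph reference")
```

Passing per-graph edge counts instead of node counts gives the edge-readout case (Test.2). Subtracting the per-segment maximum before exponentiating keeps `math.exp` from overflowing without changing the result, and empty graphs simply contribute no rows.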
Callers

nothing calls this directly

Calls 6

append · Method · 0.80
to · Method · 0.45
astype · Method · 0.45
ctx · Method · 0.45
num_nodes · Method · 0.45
num_edges · Method · 0.45

Tested by

no test coverage detected