MCPcopy
hub / github.com/dmlc/dgl / test_edge_softmax

Function test_edge_softmax

tests/python/common/ops/test_edge_softmax.py:29–56  ·  view source on GitHub ↗
(g, norm_by, shp, idtype)

Source from the content-addressed store, hash-verified

27@pytest.mark.parametrize("shp", edge_softmax_shapes)
28@parametrize_idtype
def test_edge_softmax(g, norm_by, shp, idtype):
    """Check ``edge_softmax`` against a dense softmax reference.

    The edge data is viewed as a dense ``(num_src, num_dst, *shp)`` tensor and
    a plain softmax is taken along one node axis; the result is flattened back
    to one row per edge and compared with ``edge_softmax``.  Both the forward
    scores and the gradients with respect to the edge data are compared.

    Parameters
    ----------
    g
        Graph under test.  The dense reshape only succeeds when the graph has
        exactly ``num_src * num_dst`` edges.
        NOTE(review): the reference is presumably only valid when edge ids are
        laid out source-major (edge id = src * num_dst + dst) -- confirm
        against the parametrized graph cases.
    norm_by : str
        ``"src"`` or ``"dst"``; forwarded to ``edge_softmax`` and used to pick
        the axis of the dense reference softmax.
    shp
        Trailing feature shape of the per-edge data.
    idtype
        Index dtype the graph is cast to before running the test.

    Raises
    ------
    ValueError
        If ``norm_by`` is neither ``"src"`` nor ``"dst"``.
    """
    # Axis of the dense (num_src, num_dst, ...) view that the reference softmax
    # runs over: "src" uses axis 1 (the destination axis), "dst" uses axis 0
    # (the source axis).
    softmax_axis = {"src": 1, "dst": 0}
    if norm_by not in softmax_axis:
        # Fail with a clear message instead of an unbound-variable error at
        # the forward comparison below.
        raise ValueError("unsupported norm_by: {!r}".format(norm_by))

    g = g.astype(idtype).to(F.ctx())
    edata = F.tensor(np.random.rand(g.num_edges(), *shp))
    e1 = F.attach_grad(F.clone(edata))

    # Operator under test: forward scores and gradient of their sum.
    with F.record_grad():
        score1 = edge_softmax(g, e1, norm_by=norm_by)
        F.backward(F.reduce_sum(score1))
        grad_edata = F.grad(e1)

    # Dense reference computed on an independent clone of the same edge data,
    # so the two gradient computations do not share state.
    with F.record_grad():
        e2 = F.attach_grad(F.clone(edata))
        feat_shape = e2.shape[1:]
        e2_2d = F.reshape(
            e2,
            (g.number_of_src_nodes(), g.number_of_dst_nodes(), *feat_shape),
        )
        score2 = F.softmax(e2_2d, softmax_axis[norm_by])
        # Flatten back to one row per edge to match the edge_softmax output.
        score2 = F.reshape(score2, (-1, *feat_shape))
        assert F.allclose(score1, score2)
        print("forward passed")

        F.backward(F.reduce_sum(score2))
        assert F.allclose(F.grad(e2), grad_edata)
        print("backward passed")
57
58
59def create_test_heterograph(idtype):

Callers

nothing calls this directly

Calls 11

edge_softmaxFunction · 0.90
gradMethod · 0.80
number_of_src_nodesMethod · 0.80
number_of_dst_nodesMethod · 0.80
create_test_heterographFunction · 0.70
toMethod · 0.45
astypeMethod · 0.45
ctxMethod · 0.45
num_edgesMethod · 0.45
cloneMethod · 0.45
backwardMethod · 0.45

Tested by

no test coverage detected