(g, norm_by, shp, idtype)
| 27 | @pytest.mark.parametrize("shp", edge_softmax_shapes) |
| 28 | @parametrize_idtype |
def test_edge_softmax(g, norm_by, shp, idtype):
    """Check ``edge_softmax`` forward and backward against a dense softmax.

    Random edge scores of shape ``(g.num_edges(), *shp)`` are normalized
    twice: once with ``edge_softmax`` and once with ``F.softmax`` applied to
    the same data reshaped to ``(num_src, num_dst, *shp)``.  The dense
    softmax runs along axis 1 for ``norm_by="src"`` and along axis 0 for
    ``norm_by="dst"``.  Both the scores and the gradients of their sum with
    respect to the edge data must agree.

    NOTE(review): the dense reshape presumes ``g`` holds exactly
    ``num_src * num_dst`` edges laid out source-major -- confirm against the
    graphs supplied by the ``g`` parametrization.
    """
    g = g.astype(idtype).to(F.ctx())
    edata = F.tensor(np.random.rand(g.num_edges(), *shp))

    # Implementation under test: scores and gradient of their sum.
    e1 = F.attach_grad(F.clone(edata))
    with F.record_grad():
        score1 = edge_softmax(g, e1, norm_by=norm_by)
        F.backward(F.reduce_sum(score1))
        grad_edata = F.grad(e1)

    # Dense reference computed from an independent copy of the same data.
    with F.record_grad():
        e2 = F.attach_grad(F.clone(edata))
        feat_shape = e2.shape[1:]
        e2_2d = F.reshape(
            e2,
            (g.number_of_src_nodes(), g.number_of_dst_nodes(), *feat_shape),
        )

        # The branches are mutually exclusive; fail loudly on anything else
        # instead of leaving ``score2`` unbound for the comparison below.
        if norm_by == "src":
            softmax_axis = 1
        elif norm_by == "dst":
            softmax_axis = 0
        else:
            raise ValueError("unsupported norm_by: {!r}".format(norm_by))

        # Flatten back to one row per edge so it lines up with ``score1``.
        score2 = F.reshape(F.softmax(e2_2d, softmax_axis), (-1, *feat_shape))
        assert F.allclose(score1, score2)
        print("forward passed")

        F.backward(F.reduce_sum(score2))
        assert F.allclose(F.grad(e2), grad_edata)
        print("backward passed")
| 57 | |
| 58 | |
| 59 | def create_test_heterograph(idtype): |
nothing calls this directly
no test coverage detected