MCPcopy
hub / github.com/dmlc/dgl / test_softmax

Function test_softmax

tests/python/pytorch/sparse/test_softmax.py:14–43  ·  view source on GitHub ↗
(val_D, csr, dim)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("csr", [True, False])
@pytest.mark.parametrize("dim", [0, 1])
def test_softmax(val_D, csr, dim):
    """Check sparse ``softmax`` against ``dgl.nn.functional.edge_softmax``.

    A fixed 4-nonzero COO pattern is filled with random values, run through
    ``softmax(A, dim)``, and compared with ``edge_softmax`` on a graph built
    from the same indices. Both the forward values and the gradients
    w.r.t. the input values must agree within ``atol=1e-05``.

    Args:
        val_D: ``None`` to give each nonzero a scalar value; otherwise the
            size of a trailing value dimension, i.e. values are
            ``(nnz, val_D)``.
        csr: When true, ``A.csr()`` is called before the softmax.
        dim: Dimension handed to ``softmax``; also selects how the reference
            graph's edges are oriented.
    """
    device = F.ctx()
    row = torch.tensor([0, 0, 1, 1]).to(device)
    col = torch.tensor([0, 2, 1, 2]).to(device)
    nnz = len(row)

    # One random value (or one val_D-sized vector) per nonzero entry.
    val_shape = (nnz,) if val_D is None else (nnz, val_D)
    val = torch.randn(*val_shape).to(device)

    # Independent leaf tensor so the sparse path gets its own gradient.
    sparse_input = val.clone().requires_grad_()
    A = from_coo(row, col, sparse_input)

    if csr:
        # Return value is intentionally discarded.
        # NOTE(review): presumably this materializes the CSR representation
        # so that softmax exercises it - confirm against the sparse API.
        A.csr()

    sparse_out = softmax(A, dim)

    # For dim == 1 the edge endpoints are swapped relative to (row, col).
    # NOTE(review): presumably because edge_softmax normalizes over edges
    # sharing a destination node - confirm against the DGL docs.
    src, dst = (col, row) if dim == 1 else (row, col)
    g = dgl.graph((src, dst), num_nodes=max(A.shape))

    # Second independent leaf tensor for the reference path.
    graph_input = val.clone().requires_grad_()
    expected = dgl.nn.functional.edge_softmax(g, graph_input)
    assert torch.allclose(sparse_out.val, expected, atol=1e-05)

    # Push the same upstream gradient through both paths and compare.
    upstream = torch.randn_like(expected).to(device)
    sparse_out.val.backward(upstream)
    expected.backward(upstream)
    assert torch.allclose(A.val.grad, graph_input.grad, atol=1e-05)

Callers

nothing calls this directly

Calls 9

from_cooFunction · 0.90
softmaxFunction · 0.90
csrMethod · 0.80
maxFunction · 0.50
ctxMethod · 0.45
toMethod · 0.45
cloneMethod · 0.45
graphMethod · 0.45
backwardMethod · 0.45

Tested by

no test coverage detected