@pytest.mark.parametrize("csr", [True, False])
@pytest.mark.parametrize("dim", [0, 1])
def test_softmax(val_D, csr, dim):
    """Compare sparse ``softmax`` with ``dgl.nn.functional.edge_softmax``.

    A sparse matrix with four non-zero entries is built from COO indices
    and random values. Both the forward output and the gradient w.r.t. the
    input values are checked against the graph-based reference.

    Args:
        val_D: ``None`` for one scalar per non-zero entry, otherwise the
            size of the trailing dimension of the value tensor.
        csr: When true, ``A.csr()`` is called before running ``softmax``.
        dim: Dimension argument forwarded to ``softmax``.
    """
    device = F.ctx()
    row_idx = torch.tensor([0, 0, 1, 1]).to(device)
    col_idx = torch.tensor([0, 2, 1, 2]).to(device)
    num_entries = len(row_idx)

    # Values are (nnz,) or (nnz, val_D) depending on the parametrization.
    val_shape = (num_entries,) if val_D is None else (num_entries, val_D)
    base_val = torch.randn(*val_shape).to(device)

    # Sparse-matrix path, with its own leaf tensor for gradient tracking.
    sparse_input = base_val.clone().requires_grad_()
    A = from_coo(row_idx, col_idx, sparse_input)
    if csr:
        # Test CSR: invoke csr() up front (presumably so the CSR
        # representation is used by softmax -- TODO confirm).
        A.csr()
    sparse_out = softmax(A, dim)

    # Reference path: the edge direction is swapped for dim == 1
    # (presumably so edge_softmax groups entries along the same
    # dimension -- TODO confirm against the edge_softmax docs).
    endpoints = (col_idx, row_idx) if dim == 1 else (row_idx, col_idx)
    g = dgl.graph(endpoints, num_nodes=max(A.shape))
    graph_input = base_val.clone().requires_grad_()
    expected = dgl.nn.functional.edge_softmax(g, graph_input)
    assert torch.allclose(sparse_out.val, expected, atol=1e-05)

    # Backpropagate the same upstream gradient through both paths and
    # compare the gradients that reach the input values.
    upstream = torch.randn_like(expected).to(device)
    sparse_out.val.backward(upstream)
    expected.backward(upstream)
    assert torch.allclose(A.val.grad, graph_input.grad, atol=1e-05)