(create_func, shape, nnz)
@pytest.mark.parametrize("shape", [(2, 7), (5, 2)])
@pytest.mark.parametrize("nnz", [1, 10])
def test_bspmm(create_func, shape, nnz):
    """Check sparse ``matmul`` with a batched dense operand against torch.

    A sparse matrix built by ``create_func`` is multiplied by a dense tensor
    (passed through ``rand_stride``) via ``matmul``. Three quantities are
    compared to the same product computed with ``@`` on ``A.to_dense()``:
    the forward output, the gradient w.r.t. the dense operand, and the
    gradient w.r.t. the sparse values (``A.val.grad``). For the reference,
    the last axis of every operand is moved to the front so that torch's
    batched ``@`` treats it as the batch axis.
    """
    # Size of the dense operand's trailing axis. It is also passed as the
    # fourth argument of ``create_func``, presumably the per-entry value
    # dimension of the sparse matrix -- TODO confirm against create_func.
    num_batches = 2
    num_cols = 10
    atol = 1e-05
    device = F.ctx()

    def batch_first(t):
        # Move the trailing (batch) axis to the front for torch's batched @.
        return t.permute(2, 0, 1)

    # Sparse path under test.
    sparse_mat = create_func(shape, nnz, device, num_batches)
    rhs = torch.randn(
        shape[1], num_cols, num_batches, requires_grad=True, device=device
    )
    rhs = rand_stride(rhs)

    out = matmul(sparse_mat, rhs)
    out_grad = torch.randn_like(out)
    out.backward(out_grad)

    # Dense reference path on independent copies of both operands.
    rhs_ref = clone_detach_and_grad(rhs)
    dense_mat = sparse_mat.to_dense().clone().detach().requires_grad_()
    ref_out = batch_first(dense_mat) @ batch_first(rhs_ref)
    ref_out.backward(batch_first(out_grad))

    # Compare forward output and both gradients.
    assert torch.allclose(batch_first(out), ref_out, atol=atol)
    assert torch.allclose(rhs.grad, rhs_ref.grad, atol=atol)
    # NOTE(review): dense_mask presumably restricts the dense gradient to the
    # sparsity pattern of sparse_mat, so it is comparable with val.grad.
    expected_val_grad = dense_mask(dense_mat.grad, sparse_mat)
    actual_val_grad = sparse_matrix_to_dense(
        val_like(sparse_mat, sparse_mat.val.grad)
    )
    assert torch.allclose(expected_val_grad, actual_val_grad, atol=atol)
| 87 | |
| 88 | |
| 89 | @pytest.mark.parametrize("create_func1", [rand_coo, rand_csr, rand_csc]) |
nothing calls this directly
no test coverage detected