Repository: github.com/dmlc/dgl

Function test_bspmm

tests/python/pytorch/sparse/test_matmul.py:63–86
(create_func, shape, nnz)

```python
@pytest.mark.parametrize("shape", [(2, 7), (5, 2)])
@pytest.mark.parametrize("nnz", [1, 10])
def test_bspmm(create_func, shape, nnz):
    dev = F.ctx()
    # Sparse A with nnz stored entries; each entry holds 2 values, so A.to_dense()
    # is 3-D with a trailing batch dimension of 2, matching X's last dimension.
    A = create_func(shape, nnz, dev, 2)
    X = torch.randn(shape[1], 10, 2, requires_grad=True, device=dev)
    X = rand_stride(X)

    # Forward and backward through the sparse batched matmul.
    sparse_result = matmul(A, X)
    grad = torch.randn_like(sparse_result)
    sparse_result.backward(grad)

    # Dense reference: move the batch dimension to the front and use batched @.
    XX = clone_detach_and_grad(X)
    torch_A = A.to_dense().clone().detach().requires_grad_()
    torch_result = torch_A.permute(2, 0, 1) @ XX.permute(2, 0, 1)

    torch_result.backward(grad.permute(2, 0, 1))
    assert torch.allclose(
        sparse_result.permute(2, 0, 1), torch_result, atol=1e-05
    )
    assert torch.allclose(X.grad, XX.grad, atol=1e-05)
    # Gradients of A are compared only on A's sparsity pattern.
    assert torch.allclose(
        dense_mask(torch_A.grad, A),
        sparse_matrix_to_dense(val_like(A, A.val.grad)),
        atol=1e-05,
    )
```
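The forward check compares the sparse batched product against a dense reference: `A.to_dense()` has shape (M, N, 2), and `permute(2, 0, 1)` moves the trailing batch dimension to the front so that `@` performs 2 independent (M, N) @ (N, 10) products. Below is a minimal pure-Python sketch of the same comparison, assuming a COO matrix with unique coordinates whose stored values are length-B lists. The helpers `coo_bspmm`, `coo_to_dense` and `dense_reference` are hypothetical illustrations, not DGL APIs. The sketch uses an absolute tolerance only, whereas `torch.allclose` also applies a relative tolerance (`rtol=1e-05` by default).

```python
# Hypothetical pure-Python sketch (not DGL code) of the forward check in test_bspmm.
import random
from typing import List, Tuple

Tensor3 = List[List[List[float]]]  # nested lists indexed [i][j][b]


def coo_bspmm(shape: Tuple[int, int], rows: List[int], cols: List[int],
              vals: List[List[float]], X: Tensor3) -> Tensor3:
    """Y[m][j][b] = sum of vals[k][b] * X[cols[k]][j][b] over k with rows[k] == m."""
    M = shape[0]
    J, B = len(X[0]), len(X[0][0])
    Y = [[[0.0] * B for _ in range(J)] for _ in range(M)]
    for r, c, v in zip(rows, cols, vals):  # visit only the stored entries
        for j in range(J):
            for b in range(B):
                Y[r][j][b] += v[b] * X[c][j][b]
    return Y


def coo_to_dense(shape, rows, cols, vals) -> Tensor3:
    """Dense (M, N, B) copy of the sparse matrix, zero off the sparsity pattern."""
    M, N = shape
    B = len(vals[0])
    D = [[[0.0] * B for _ in range(N)] for _ in range(M)]
    for r, c, v in zip(rows, cols, vals):
        for b in range(B):
            D[r][c][b] += v[b]
    return D


def dense_reference(D: Tensor3, X: Tensor3) -> Tensor3:
    """One ordinary (M, N) @ (N, J) product per batch slice b, which is what
    permute(2, 0, 1) followed by @ computes; the result is laid out as (M, J, B)."""
    M, N, B = len(D), len(D[0]), len(D[0][0])
    J = len(X[0])
    return [[[sum(D[m][n][b] * X[n][j][b] for n in range(N)) for b in range(B)]
             for j in range(J)] for m in range(M)]


random.seed(0)
shape, nnz, B, J = (2, 7), 10, 2, 10  # one of the test's (shape, nnz) combinations
coords = random.sample([(i, j) for i in range(shape[0]) for j in range(shape[1])], nnz)
rows = [r for r, _ in coords]
cols = [c for _, c in coords]
vals = [[random.gauss(0.0, 1.0) for _ in range(B)] for _ in coords]
X = [[[random.gauss(0.0, 1.0) for _ in range(B)] for _ in range(J)] for _ in range(shape[1])]

Y = coo_bspmm(shape, rows, cols, vals, X)
R = dense_reference(coo_to_dense(shape, rows, cols, vals), X)
assert all(abs(p - q) <= 1e-5
           for Y1, R1 in zip(Y, R) for Y2, R2 in zip(Y1, R1) for p, q in zip(Y2, R2))
```

The sparse loop only touches stored entries, while the dense reference multiplies through every structural zero; agreement between the two is what the first assertion in the test establishes.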
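The two backward checks rest on the standard matmul gradient identities, applied per batch slice b: for Y = A X with upstream gradient G, dL/dX = A^T G and dL/dA = G X^T. The dense dL/dA is generally nonzero at every position, but a sparse matrix only has gradients for its stored values, so the test restricts the dense gradient to A's sparsity pattern before comparing (our reading of `dense_mask` from how it is used here). A minimal pure-Python sketch, assuming a COO matrix with unique coordinates and length-B value lists; every name in it is hypothetical and the sizes are illustrative.

```python
# Hypothetical pure-Python sketch (not DGL code) of the gradient checks in test_bspmm.
import random

random.seed(1)
M, N, J, B = 5, 2, 10, 2  # matrix shape (5, 2) as in one test case; nnz chosen as 4 here


def rand3(a, b, c):
    """Random nested list of shape (a, b, c)."""
    return [[[random.gauss(0.0, 1.0) for _ in range(c)] for _ in range(b)] for _ in range(a)]


coords = random.sample([(i, j) for i in range(M) for j in range(N)], 4)  # unique (row, col)
rows = [r for r, _ in coords]
cols = [c for _, c in coords]
vals = [[random.gauss(0.0, 1.0) for _ in range(B)] for _ in coords]
X = rand3(N, J, B)
G = rand3(M, J, B)  # upstream gradient dL/dY, shaped like the (M, J, B) output

# Sparse backward: only the stored values receive gradients.
grad_vals = [[sum(G[r][j][b] * X[c][j][b] for j in range(J)) for b in range(B)]
             for r, c in zip(rows, cols)]
grad_X = [[[0.0] * B for _ in range(J)] for _ in range(N)]
for r, c, v in zip(rows, cols, vals):
    for j in range(J):
        for b in range(B):
            grad_X[c][j][b] += v[b] * G[r][j][b]

# Dense backward through a dense copy D of A: dL/dA exists at every (m, n).
D = [[[0.0] * B for _ in range(N)] for _ in range(M)]
for r, c, v in zip(rows, cols, vals):
    for b in range(B):
        D[r][c][b] += v[b]
dense_grad_A = [[[sum(G[m][j][b] * X[n][j][b] for j in range(J)) for b in range(B)]
                 for n in range(N)] for m in range(M)]
dense_grad_X = [[[sum(D[m][n][b] * G[m][j][b] for m in range(M)) for b in range(B)]
                 for j in range(J)] for n in range(N)]

# Masking: read the dense gradient only at the stored coordinates.
masked = [dense_grad_A[r][c] for r, c in zip(rows, cols)]

assert all(abs(p - q) <= 1e-5 for mv, gv in zip(masked, grad_vals) for p, q in zip(mv, gv))
assert all(abs(p - q) <= 1e-5
           for A1, B1 in zip(grad_X, dense_grad_X) for A2, B2 in zip(A1, B1) for p, q in zip(A2, B2))
```

Without the mask, the comparison would fail at positions outside the sparsity pattern, where the dense gradient is nonzero but the sparse matrix has no value to differentiate.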

Callers

No direct callers; pytest collects and runs it as a test function.

Calls (10)

matmul (function) · 0.90
val_like (function) · 0.90
rand_stride (function) · 0.85
clone_detach_and_grad (function) · 0.85
dense_mask (function) · 0.85
sparse_matrix_to_dense (function) · 0.85
to_dense (method) · 0.80
ctx (method) · 0.45
backward (method) · 0.45
clone (method) · 0.45

Tested by

no test coverage detected