Repository: github.com/dmlc/dgl

Function test_spmm

tests/python/pytorch/sparse/test_matmul.py, lines 32–57
(create_func, shape, nnz, out_dim)


```python
# create_func and shape are supplied by parametrization or fixtures not shown in this excerpt.
@pytest.mark.parametrize("nnz", [1, 10])
@pytest.mark.parametrize("out_dim", [None, 10])
def test_spmm(create_func, shape, nnz, out_dim):
    dev = F.ctx()
    # Random sparse matrix A with the given shape and number of stored values.
    A = create_func(shape, nnz, dev)
    # Dense operand: a matrix when out_dim is set, otherwise a 1-D vector.
    if out_dim is not None:
        X = torch.randn(shape[1], out_dim, requires_grad=True, device=dev)
    else:
        X = torch.randn(shape[1], requires_grad=True, device=dev)

    # Sparse path: forward, then backward with a random upstream gradient.
    X = rand_stride(X)
    sparse_result = matmul(A, X)
    grad = torch.randn_like(sparse_result)
    sparse_result.backward(grad)

    # Dense reference path on a detached copy of X.
    adj = sparse_matrix_to_dense(A)
    XX = clone_detach_and_grad(X)
    dense_result = torch.matmul(adj, XX)
    if out_dim is None:
        dense_result = dense_result.view(-1)
    dense_result.backward(grad)

    # Outputs, X gradients, and A's value gradients must agree; the dense
    # adjacency gradient is masked to A's sparsity pattern before comparing.
    assert torch.allclose(sparse_result, dense_result, atol=1e-05)
    assert torch.allclose(X.grad, XX.grad, atol=1e-05)
    assert torch.allclose(
        dense_mask(adj.grad, A),
        sparse_matrix_to_dense(val_like(A, A.val.grad)),
        atol=1e-05,
    )
```
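The test checks `matmul` against a dense reference in three ways: the forward result, the gradient with respect to `X`, and the gradient with respect to `A`'s stored values, which it compares with the dense adjacency gradient restricted (via `dense_mask`, judging by how it is used) to `A`'s nonzero pattern. The sketch below reproduces that checking pattern in plain Python with no DGL or PyTorch dependency. All helper names (`spmm_coo`, `spmm_coo_backward`, `check_spmm`, etc.) are hypothetical stand-ins, and it assumes distinct COO coordinates so each stored value maps to exactly one dense cell. The identities it relies on are the standard ones for `Y = A X`: `dL/dX = Aᵀ G` and `dL/dA = G Xᵀ`, where `G = dL/dY`.

```python
import random

# Hypothetical helpers (not DGL APIs). Matrices are lists of row lists.

def to_dense(shape, rows, cols, vals):
    """Scatter COO entries into a dense m x n matrix."""
    m, n = shape
    D = [[0.0] * n for _ in range(m)]
    for r, c, v in zip(rows, cols, vals):
        D[r][c] += v
    return D

def dense_matmul(A, B):
    """Naive dense product of an (m x k) and a (k x d) matrix."""
    k, d = len(B), len(B[0])
    return [[sum(row[t] * B[t][j] for t in range(k)) for j in range(d)] for row in A]

def transpose(M):
    return [list(col) for col in zip(*M)]

def allclose(P, Q, atol=1e-5):
    return all(abs(p - q) <= atol for rp, rq in zip(P, Q) for p, q in zip(rp, rq))

def spmm_coo(shape, rows, cols, vals, X):
    """Forward: Y = A @ X, touching only the stored entries of A."""
    m, d = shape[0], len(X[0])
    Y = [[0.0] * d for _ in range(m)]
    for r, c, v in zip(rows, cols, vals):
        for j in range(d):
            Y[r][j] += v * X[c][j]
    return Y

def spmm_coo_backward(shape, rows, cols, vals, X, G):
    """Given G = dL/dY, return dL/dX = A^T G and dL/dvals[k] = (G X^T)[rows[k], cols[k]]."""
    n, d = shape[1], len(X[0])
    dX = [[0.0] * d for _ in range(n)]
    dvals = []
    for r, c, v in zip(rows, cols, vals):
        dvals.append(sum(G[r][j] * X[c][j] for j in range(d)))
        for j in range(d):
            dX[c][j] += v * G[r][j]  # each nonzero scatters v * G[r] into row c
    return dX, dvals

def check_spmm(shape=(4, 3), nnz=5, d=2, seed=0):
    rng = random.Random(seed)
    m, n = shape
    # Assumption: distinct (row, col) coordinates.
    coords = rng.sample([(i, j) for i in range(m) for j in range(n)], nnz)
    rows, cols = [r for r, _ in coords], [c for _, c in coords]
    vals = [rng.gauss(0.0, 1.0) for _ in range(nnz)]
    X = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    G = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]  # random upstream gradient

    # Sparse path: forward plus hand-written backward.
    Y_sparse = spmm_coo(shape, rows, cols, vals, X)
    dX_sparse, dvals = spmm_coo_backward(shape, rows, cols, vals, X, G)

    # Dense reference: Y = A X, dL/dX = A^T G, dL/dA = G X^T.
    A = to_dense(shape, rows, cols, vals)
    Y_dense = dense_matmul(A, X)
    dX_dense = dense_matmul(transpose(A), G)
    dA_dense = dense_matmul(G, transpose(X))

    # Keep dL/dA only on A's sparsity pattern before comparing with dL/dvals.
    mask = to_dense(shape, rows, cols, [1.0] * nnz)
    dA_masked = [[dA_dense[i][j] * mask[i][j] for j in range(n)] for i in range(m)]
    dA_from_vals = to_dense(shape, rows, cols, dvals)

    return (
        allclose(Y_sparse, Y_dense),
        allclose(dX_sparse, dX_dense),
        allclose(dA_masked, dA_from_vals),
    )

print(check_spmm())  # (True, True, True)
```

Setting `d=1` keeps the dense operand as a single column, which loosely mirrors the test's `out_dim=None` case (there the operand is a true 1-D vector, and the dense result is flattened with `view(-1)` before comparison). The masking step matters because the dense gradient `G Xᵀ` is generally nonzero everywhere, while a sparse matrix only has gradients for the entries it stores.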

Callers

nothing calls this directly

Calls (8)

matmul · Function · 0.90
val_like · Function · 0.90
rand_stride · Function · 0.85
sparse_matrix_to_dense · Function · 0.85
clone_detach_and_grad · Function · 0.85
dense_mask · Function · 0.85
ctx · Method · 0.45
backward · Method · 0.45

Tested by

no test coverage detected