@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
def test_addsub_csr(val_shape, opname):
    """Check sparse CSR add/sub against the dense reference result.

    Two CSR sparse matrices are built with random values. The result of the
    Python operator (``operator.add`` / ``operator.sub``) and of the
    ``dglsp`` function of the same name must both match that operator
    applied to the matrices' dense forms. Mixing a sparse matrix with a
    Python scalar on either side must raise ``TypeError``.
    """
    binary_op = getattr(operator, opname)
    sparse_fn = getattr(dglsp, opname)
    device = F.ctx()

    def _csr(row_ptr, col_idx, **kwargs):
        # Build a CSR matrix on ``device``. Its non-zero values are random,
        # with shape ``(nnz,) + val_shape``; extra kwargs go to from_csr.
        row_ptr = torch.tensor(row_ptr).to(device)
        col_idx = torch.tensor(col_idx).to(device)
        values = torch.randn(col_idx.shape + val_shape).to(device)
        return dglsp.from_csr(row_ptr, col_idx, values, **kwargs)

    lhs = _csr([0, 1, 2, 3], [3, 0, 2])
    # The second matrix has fewer non-zeros, so its shape is pinned to the
    # first matrix's shape to keep the two operands compatible.
    rhs = _csr([0, 1, 2, 2], [2, 0], shape=lhs.shape)

    via_operator = binary_op(lhs, rhs).to_dense()
    via_function = sparse_fn(lhs, rhs).to_dense()
    reference = binary_op(lhs.to_dense(), rhs.to_dense())

    for result in (via_operator, via_function):
        assert torch.allclose(reference, result)

    # A sparse matrix combined with a plain scalar is rejected in either
    # operand position.
    for bad_args in ((lhs, 2), (2, lhs)):
        with pytest.raises(TypeError):
            binary_op(*bad_args)
| 125 | |
| 126 | |
| 127 | @pytest.mark.parametrize("val_shape", [(), (2,)]) |