@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
def test_addsub_csc(val_shape, opname):
    """Compare sparse add/sub on CSC-built matrices with the dense result.

    Two sparse matrices are created through ``dglsp.from_csc`` with random
    values of shape ``(nnz,) + val_shape``.  The operation named by
    ``opname`` is applied both through the ``operator`` module function and
    through the ``dglsp`` function of the same name; each result, once
    densified, must match the same operation applied to the dense operands.
    Finally, combining a sparse matrix with the plain integer ``2`` on
    either side is expected to raise ``TypeError``.
    """
    py_op = getattr(operator, opname)
    sparse_fn = getattr(dglsp, opname)
    device = F.ctx()

    # Left operand: three stored entries; its shape is left to from_csc.
    lhs_indptr = torch.tensor([0, 1, 1, 2, 3]).to(device)
    lhs_indices = torch.tensor([1, 2, 0]).to(device)
    lhs_val = torch.randn(lhs_indices.shape + val_shape).to(device)
    A = dglsp.from_csc(lhs_indptr, lhs_indices, lhs_val)

    # Right operand: two stored entries, explicitly given A's shape.
    rhs_indptr = torch.tensor([0, 1, 1, 2, 2]).to(device)
    rhs_indices = torch.tensor([1, 0]).to(device)
    rhs_val = torch.randn(rhs_indices.shape + val_shape).to(device)
    B = dglsp.from_csc(rhs_indptr, rhs_indices, rhs_val, shape=A.shape)

    via_operator = py_op(A, B).to_dense()
    via_function = sparse_fn(A, B).to_dense()
    expected = py_op(A.to_dense(), B.to_dense())

    for actual in (via_operator, via_function):
        assert torch.allclose(expected, actual)

    # A scalar operand is rejected regardless of which side it is on.
    for left, right in ((A, 2), (2, A)):
        with pytest.raises(TypeError):
            py_op(left, right)
| 154 | |
| 155 | |
| 156 | @pytest.mark.parametrize("val_shape", [(), (2,)]) |