@pytest.mark.parametrize("val_shape", [(), (2,)])
@pytest.mark.parametrize("opname", ["add", "sub"])
def test_addsub_diag(val_shape, opname):
    """Check add/sub of two diagonal matrices against the dense reference.

    Two random diagonal matrices of shape (3, 4) are combined in two ways:
    with the Python operator named ``opname`` (``operator.add`` or
    ``operator.sub``) and with the ``dglsp`` function of the same name.
    Both results, densified, must be close to applying the Python operator
    to the densified inputs. Applying the Python operator to a diagonal
    matrix and the int ``2`` (in either operand order) must raise
    ``TypeError``.
    """
    python_op = getattr(operator, opname)
    dglsp_op = getattr(dglsp, opname)
    ctx = F.ctx()
    mat_shape = (3, 4)
    # Leading value dim equals the row count; any trailing dims come from
    # the parametrized ``val_shape``.
    full_val_shape = (mat_shape[0], *val_shape)

    def make_diag():
        return dglsp.diag(torch.randn(full_val_shape).to(ctx), shape=mat_shape)

    # Left-to-right evaluation keeps the same random-draw order (lhs first).
    lhs, rhs = make_diag(), make_diag()

    sparse_results = [
        binary(lhs, rhs).to_dense() for binary in (python_op, dglsp_op)
    ]
    expected = python_op(lhs.to_dense(), rhs.to_dense())

    for result in sparse_results:
        assert torch.allclose(expected, result)

    # Adding/subtracting a plain Python scalar is not supported.
    for operands in ((lhs, 2), (2, lhs)):
        with pytest.raises(TypeError):
            python_op(*operands)
| 178 | |
| 179 | |
| 180 | @pytest.mark.parametrize("val_shape", [(), (2,)]) |