| 179 | |
@pytest.mark.parametrize("val_shape", [(), (2,)])
def test_add_sparse_diag(val_shape):
    """Sparse + diagonal addition must agree with dense addition.

    Both operand orders are checked, via the ``+`` operator and via
    ``dglsp.add``, for scalar (``()``) and vector (``(2,)``) per-entry
    values.
    """
    ctx = F.ctx()

    # Three nonzeros; every (row, col) index fits inside the (3, 4) shape
    # used for the diagonal matrix below.
    rows = torch.tensor([1, 0, 2]).to(ctx)
    cols = torch.tensor([0, 3, 2]).to(ctx)
    nnz_vals = torch.randn(rows.shape + val_shape).to(ctx)
    A = dglsp.from_coo(rows, cols, nnz_vals)

    shape = (3, 4)
    # shape[0] diagonal entries, each carrying a trailing `val_shape` block.
    diag_vals = torch.randn((shape[0],) + val_shape).to(ctx)
    D = dglsp.diag(diag_vals, shape=shape)

    results = (
        (A + D).to_dense(),
        (D + A).to_dense(),
        dglsp.add(A, D).to_dense(),
        dglsp.add(D, A).to_dense(),
    )
    expected = A.to_dense() + D.to_dense()

    for result in results:
        assert torch.allclose(expected, result)
| 202 | |
| 203 | |
| 204 | @pytest.mark.parametrize("val_shape", [(), (2,)]) |