| 361 | |
@pytest.mark.parametrize("shape", [(3, 5), (5, 5), (5, 4)])
def test_diag_conversions(shape):
    """Check COO/CSR/CSC conversions of a (possibly non-square) diag matrix.

    A diagonal matrix of ``shape`` holds ``min(shape)`` entries at positions
    ``(i, i)``. In CSR, row ``i`` therefore has exactly one entry for
    ``i < nnz`` and none afterwards, so ``indptr`` is ``0, 1, ..., nnz``
    padded with ``nnz`` up to ``n_rows + 1`` elements. CSC is the same with
    columns and ``n_cols`` in place of rows and ``n_rows``.
    """
    n_rows, n_cols = shape
    nnz = min(shape)
    ctx = F.ctx()
    val = torch.randn(nnz).to(ctx)
    D = diag(val, shape)
    # Indices of the stored diagonal entries: 0, 1, ..., nnz - 1.
    diag_idx = torch.arange(nnz).to(ctx)

    # Index tensors are integers, so compare them exactly with torch.equal.
    # Unlike torch.allclose, it does not broadcast, so a result with the
    # wrong length fails instead of silently passing.
    row, col = D.coo()
    assert torch.equal(row, diag_idx)
    assert torch.equal(col, diag_idx)

    indptr, indices, _ = D.csr()
    exp_indptr = list(range(0, nnz + 1)) + [nnz] * (n_rows - nnz)
    assert torch.equal(indptr, torch.tensor(exp_indptr).to(ctx))
    assert torch.equal(indices, diag_idx)

    indptr, indices, _ = D.csc()
    exp_indptr = list(range(0, nnz + 1)) + [nnz] * (n_cols - nnz)
    assert torch.equal(indptr, torch.tensor(exp_indptr).to(ctx))
    assert torch.equal(indices, diag_idx)
| 382 | |
| 383 | |
| 384 | @pytest.mark.parametrize("val_shape", [(3), (3, 2)]) |