@pytest.mark.parametrize("indices", [(0, 1, 2, 3), (1, 2, 3, 4)])
@pytest.mark.parametrize("shape", [None, (3, 5)])
def test_from_csr(dense_dim, indptr, indices, shape):
    """Check that ``from_csr`` builds a sparse matrix matching its CSR inputs.

    Random values are generated with one value per entry of ``indices`` (plus
    a trailing ``dense_dim`` axis when ``dense_dim`` is not None). The sparse
    matrix is then built with ``from_csr``. The test verifies its device,
    shape, nnz and dtype, and checks that ``mat.csr()`` round-trips the
    original index pointers, column indices and values.

    When ``shape`` is None, the expected shape is derived from the inputs as
    ``(len(indptr) - 1, max(indices) + 1)``.
    """
    val_shape = (len(indices),)
    if dense_dim is not None:
        val_shape += (dense_dim,)
    ctx = F.ctx()
    val = torch.randn(val_shape).to(ctx)
    indptr = torch.tensor(indptr).to(ctx)
    indices = torch.tensor(indices).to(ctx)
    mat = from_csr(indptr, indices, val, shape)

    if shape is None:
        # No explicit shape: rows come from indptr, columns from the
        # largest column index present.
        shape = (indptr.numel() - 1, torch.max(indices).item() + 1)

    assert mat.device == val.device
    assert mat.shape == shape
    assert mat.nnz == indices.numel()
    assert mat.dtype == val.dtype
    mat_indptr, mat_indices, value_indices = mat.csr()
    # value_indices may be None; when present it is used to reorder mat.val
    # so it can be compared against the input values.
    mat_val = mat.val if value_indices is None else mat.val[value_indices]
    # Index arrays are integers: compare them exactly (same shape and same
    # elements) instead of with the tolerance-based, broadcasting allclose.
    assert torch.equal(mat_indptr, indptr)
    assert torch.equal(mat_indices, indices)
    # Values are floats, so allclose is appropriate, but allclose broadcasts;
    # check the shape explicitly so a dropped/extra axis cannot slip through.
    assert mat_val.shape == val.shape
    assert torch.allclose(mat_val, val)
| 90 | |
| 91 | |
| 92 | @pytest.mark.parametrize("dense_dim", [None, 4]) |