| 94 | @pytest.mark.parametrize("indices", [(0, 1, 2, 3), (1, 2, 3, 4)]) |
| 95 | @pytest.mark.parametrize("shape", [None, (5, 3)]) |
def test_from_csc(dense_dim, indptr, indices, shape):
    """Check that ``from_csc`` builds a matrix consistent with its inputs.

    A random value tensor is created with one entry per element of
    ``indices`` (plus a trailing ``dense_dim`` axis when that is not
    ``None``), the matrix is built via ``from_csc``, and the test then
    verifies the matrix's device, shape, ``nnz``, dtype, and that the CSC
    representation returned by ``mat.csc()`` matches what was passed in.

    Args:
        dense_dim: Size of the optional trailing value dimension, or
            ``None`` for one scalar value per entry.
        indptr: Sequence of integers converted to the CSC index-pointer
            tensor.
        indices: Sequence of integers converted to the CSC indices tensor.
        shape: Explicit matrix shape, or ``None`` to let ``from_csc``
            choose one.
    """
    device = F.ctx()

    # One value per stored entry; optionally a vector of size `dense_dim`.
    num_entries = len(indices)
    val_shape = (
        (num_entries,) if dense_dim is None else (num_entries, dense_dim)
    )
    val = torch.randn(val_shape).to(device)
    indptr = torch.tensor(indptr).to(device)
    indices = torch.tensor(indices).to(device)

    mat = from_csc(indptr, indices, val, shape)

    # With no explicit shape, the expected one is derived from the inputs:
    # (largest index + 1, number of index-pointer intervals).
    expected_shape = shape
    if expected_shape is None:
        expected_shape = (
            torch.max(indices).item() + 1,
            indptr.numel() - 1,
        )

    # Metadata checks.
    assert mat.device == val.device
    assert mat.shape == expected_shape
    assert mat.nnz == indices.numel()
    assert mat.dtype == val.dtype

    # Round-trip check: `csc()` may return a permutation of the stored
    # values (`value_indices`); apply it before comparing when present.
    mat_indptr, mat_indices, value_indices = mat.csc()
    if value_indices is None:
        mat_val = mat.val
    else:
        mat_val = mat.val[value_indices]
    assert torch.allclose(mat_indptr, indptr)
    assert torch.allclose(mat_indices, indices)
    assert torch.allclose(mat_val, val)
| 118 | |
| 119 | |
| 120 | @pytest.mark.parametrize("val_shape", [(3), (3, 2)]) |