| 289 | @pytest.mark.parametrize("col", [(0, 1, 2, 2), (1, 3, 3, 4)])  # every col tuple is sorted ascending; the expected CSC indices/values below rely on this |
| 290 | @pytest.mark.parametrize("shape", [None, (5, 5), (5, 6)])  # None exercises shape inference from the largest row/col index |
| 291 | def test_coo_to_csc(dense_dim, row, col, shape):  # Compare from_coo(...).csc() against a reference CSC built from the raw COO inputs; dense_dim and row are presumably parametrized by decorators above this excerpt |
| 292 | val_shape = (len(row),)  # one value per COO entry |
| 293 | if dense_dim is not None:  # optionally give each entry a trailing dense dimension |
| 294 | val_shape += (dense_dim,) |
| 295 | ctx = F.ctx()  # target context from the test backend helper F (defined elsewhere; presumably a torch device) |
| 296 | val = torch.randn(val_shape).to(ctx) |
| 297 | row = torch.tensor(row).to(ctx) |
| 298 | col = torch.tensor(col).to(ctx) |
| 299 | mat = from_coo(row, col, val, shape)  # sparse matrix under test |
| 300 | |
| 301 | if shape is None:  # expected shape when none is given: one past the largest row/col index |
| 302 | shape = (torch.max(row).item() + 1, torch.max(col).item() + 1) |
| 303 | |
| 304 | mat_indptr, mat_indices, value_indices = mat.csc()  # CSC column pointers, row indices, and an optional permutation into mat.val |
| 305 | mat_val = mat.val if value_indices is None else mat.val[value_indices]  # None: use mat.val as-is; otherwise gather values into CSC order |
| 306 | indptr = torch.zeros(shape[1] + 1).to(ctx)  # reference column pointers: one slot per column plus one |
| 307 | _scatter_add(indptr, col + 1)  # _scatter_add is defined elsewhere; assumed to add 1 at each index, giving per-column counts shifted by one |
| 308 | indptr = torch.cumsum(indptr, 0).long()  # prefix sum turns counts into pointers; cast to integer for comparison |
| 309 | indices = row  # valid only because col is already sorted (and entries sharing a column are assumed to keep input order) |
| 310 | |
| 311 | assert mat.shape == shape |
| 312 | assert mat.nnz == row.numel()  # nnz should equal the number of COO entries supplied |
| 313 | assert mat.dtype == val.dtype  # value dtype is preserved |
| 314 | assert torch.allclose(mat_val, val)  # values compared in input order (same sorted-col assumption as indices) |
| 315 | assert torch.allclose(mat_indptr, indptr)  # NOTE(review): integer tensors; torch.equal would state exact-match intent more directly |
| 316 | assert torch.allclose(mat_indices, indices) |
| 317 | |
| 318 | |
| 319 | @pytest.mark.parametrize("dense_dim", [None, 4]) |