(create_func, dim, index)
@pytest.mark.parametrize("dim", [0, 1])
@pytest.mark.parametrize("index", [None, (1, 3), (4, 0, 2)])
def test_compact(create_func, dim, index):
    """Check ``A.compact(dim, index)`` against an ``index_select`` reference.

    The expected slice ids along ``dim`` are built in two steps:

    * first, the user-supplied ``index``, deduplicated with order preserved;
    * then every other non-empty row (``dim == 0``) or column (``dim == 1``),
      in ascending order.

    The test asserts two things. The compacted matrix must equal
    ``A.index_select(dim, expected_ids)`` in both shape and values. The ids
    returned by ``compact`` must equal ``expected_ids`` exactly.
    """
    ctx = F.ctx()
    shape = (5, 5)

    # Expected ids start with the given index, deduplicated, order kept.
    ans_idx = [] if index is None else list(dict.fromkeys(index))
    if index is not None:
        index = torch.tensor(index).to(ctx)

    A = create_func(shape, 8, ctx)

    A_compact, ret_id = A.compact(dim, index)
    A_compact_dense = sparse_matrix_to_dense(A_compact)

    A_dense = sparse_matrix_to_dense(A)

    # Append every remaining non-empty slice along `dim` in ascending order.
    # `select(dim, i)` is row i for dim=0 and column i for dim=1. A set built
    # once keeps the membership test O(1). Each `i` is visited once, so the
    # set does not need updating.
    # NOTE(review): emptiness is judged from dense *values* via nonzero(),
    # so a stored entry whose value is exactly 0 would count as empty.
    given = set(ans_idx)
    for i in range(shape[dim]):
        if i not in given and A_dense.select(dim, i).nonzero().numel() > 0:
            ans_idx.append(i)

    # Always build an int64 tensor. This way an empty selection is still a
    # valid argument to `index_select` and comparable to `ret_id`.
    ans_idx = torch.tensor(ans_idx, dtype=torch.int64).to(ctx)
    A_dense_select = sparse_matrix_to_dense(A.index_select(dim, ans_idx))

    assert A_compact_dense.shape == A_dense_select.shape
    assert torch.allclose(A_compact_dense, A_dense_select)
    # Slice ids are integers: require exact equality, not closeness.
    assert torch.equal(ans_idx, ret_id)
nothing calls this directly
no test coverage detected