(create_func, shape, dense_dim, select_dim, index)
@pytest.mark.parametrize("select_dim", [0, 1])
@pytest.mark.parametrize("index", [(0, 1, 3), (1, 2)])
def test_index_select(create_func, shape, dense_dim, select_dim, index):
    """Check ``index_select`` on a sparse matrix against a dense reference.

    A sparse matrix is built with ``create_func``. Entries are selected along
    ``select_dim`` at the positions in ``index``. The densified selection must
    match ``torch.index_select`` applied to the densified original, in both
    shape and values.
    """
    device = F.ctx()
    # NOTE(review): the meaning of the literal 20 (presumably an nnz count)
    # is defined by create_func, which is not visible here -- confirm.
    sp_mat = create_func(shape, 20, device, dense_dim)
    idx_tensor = torch.tensor(index).to(device)

    # Sparse path first, matching the original call order.
    selected = sp_mat.index_select(select_dim, idx_tensor)

    # Dense reference: apply the same selection to the densified input.
    expected = torch.index_select(
        sparse_matrix_to_dense(sp_mat), select_dim, idx_tensor
    )
    actual = sparse_matrix_to_dense(selected)

    assert actual.shape == expected.shape
    assert torch.allclose(actual, expected)
| 481 | |
| 482 | |
| 483 | @pytest.mark.parametrize( |
nothing calls this directly
no test coverage detected