(create_func, index, replace, bias)
| 553 | @pytest.mark.parametrize("replace", [False, True]) |
| 554 | @pytest.mark.parametrize("bias", [False, True]) |
def test_sample_columnwise(create_func, index, replace, bias):
    """Check column-wise sampling of a sparse matrix against its dense form.

    A 5x5 matrix is built with ``create_func``, ``sample`` is called along
    dimension 1 for the columns listed in ``index`` with 3 samples per
    column, and the result is verified as follows:

    * every nonzero row of a sampled column is also nonzero in the
      corresponding source column;
    * with ``replace`` set, each non-empty source column contributes
      exactly ``sample_num`` entries and each empty one contributes none;
    * without ``replace``, each sampled column holds
      ``min(sample_num, <nonzeros in the source column>)`` nonzero rows and
      the sampled matrix reports no duplicates;
    * the sampled matrix has shape ``(shape[0], len(index))``.
    """
    ctx = F.ctx()
    shape = (5, 5)
    sample_dim = 1
    sample_num = 3

    A = create_func(shape, 10, ctx)
    # Swap the values for their absolute values before sampling.
    # NOTE(review): presumably so the values are usable as sampling weights
    # when ``bias`` is set -- confirm against the ``sample`` contract.
    A = val_like(A, torch.abs(A.val))

    index = torch.tensor(index).to(ctx)

    A_sample = A.sample(sample_dim, sample_num, index, replace, bias)
    dense_source = sparse_matrix_to_dense(A)
    dense_sampled = sparse_matrix_to_dense(A_sample)

    # Column ids of the sampled entries as plain ints, converted once so the
    # per-column count below is an ordinary list operation.
    sampled_col_ids = A_sample.col.tolist()

    for out_col, src_col in enumerate(index.tolist()):
        source_rows = dense_source[:, src_col].nonzero().reshape(-1).tolist()
        sampled_rows = dense_sampled[:, out_col].nonzero().reshape(-1).tolist()

        # Sampling must never produce an entry missing from the source column.
        assert all(row in source_rows for row in sampled_rows)

        if replace:
            # Count entries rather than nonzero rows: with replacement,
            # 'sample_num' entries are expected for a non-empty column and
            # none for an empty one.
            expected_count = sample_num if source_rows else 0
            assert sampled_col_ids.count(out_col) == expected_count
        else:
            assert len(sampled_rows) == min(sample_num, len(source_rows))

    assert A_sample.shape == (shape[0], index.size(0))
    if not replace:
        assert not A_sample.has_duplicate()
| 589 | |
| 590 | |
| 591 | def test_print(): |
nothing calls this directly
no test coverage detected