(create_func, index, replace, bias)
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("bias", [False, True])
def test_sample_rowwise(create_func, index, replace, bias):
    """Check row-wise ``SparseMatrix.sample`` against the dense source matrix.

    Samples up to ``sample_num`` elements from each row of ``A`` selected by
    ``index`` (``sample_dim == 0``) and verifies that:

    * the sampled matrix has shape ``(len(index), A.shape[1])``;
    * every non-zero column in output row ``i`` is also a non-zero column of
      source row ``index[i]``;
    * with replacement, output row ``i`` stores exactly ``sample_num`` entries
      (duplicates included) when the source row is non-empty, otherwise none;
    * without replacement, output row ``i`` has
      ``min(sample_num, nnz(source row))`` non-zero columns and the result
      contains no duplicate entries.
    """
    ctx = F.ctx()
    shape = (5, 5)
    sample_dim = 0
    sample_num = 3
    A = create_func(shape, 10, ctx)
    # Make all values non-negative; presumably so they can act as sampling
    # weights when ``bias`` is True (TODO confirm against A.sample docs).
    A = val_like(A, torch.abs(A.val))

    index = torch.tensor(index).to(ctx)

    A_sample = A.sample(sample_dim, sample_num, index, replace, bias)
    A_dense = sparse_matrix_to_dense(A)
    A_sample_to_dense = sparse_matrix_to_dense(A_sample)

    ans_shape = (index.size(0), shape[1])
    # Verify that sampled elements come from the corresponding source rows.
    for i, row in enumerate(index.tolist()):
        ans_ele = set(A_dense[row].nonzero().view(-1).tolist())
        ret_ele = set(A_sample_to_dense[i].nonzero().view(-1).tolist())
        assert ret_ele <= ans_ele
        if replace:
            # Repeated samples may collapse into one dense entry, so count
            # this output row's stored entries in the sparse result instead.
            num_sampled = int((A_sample.row == i).sum())
            assert num_sampled == (sample_num if ans_ele else 0)
        else:
            assert len(ret_ele) == min(sample_num, len(ans_ele))

    assert A_sample.shape == ans_shape
    if not replace:
        assert not A_sample.has_duplicate()
| 547 | |
| 548 | |
| 549 | @pytest.mark.parametrize( |
nothing calls this directly
no test coverage detected