MCPcopy
hub / github.com/dmlc/dgl / test_sample_rowwise

Function test_sample_rowwise

tests/python/pytorch/sparse/test_sparse_matrix.py:513–546  ·  view source on GitHub ↗
(create_func, index, replace, bias)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("bias", [False, True])
def test_sample_rowwise(create_func, index, replace, bias):
    """Check row-wise ``SparseMatrix.sample`` against the dense original.

    ``create_func`` and ``index`` are supplied by decorators or fixtures
    outside this excerpt. A 5x5 sparse matrix is built with
    ``create_func(shape, 10, ctx)`` and the rows listed in ``index`` are
    sampled with ``sample_num`` draws each. The test asserts that:

    * every sampled entry exists in the corresponding source row;
    * with ``replace=True`` each non-empty row contributes exactly
      ``sample_num`` stored entries and each empty row contributes none;
    * with ``replace=False`` each row contributes
      ``min(sample_num, row_nnz)`` distinct columns and the result has no
      duplicate entries;
    * the result shape is ``(len(index), shape[1])``.
    """
    ctx = F.ctx()
    shape = (5, 5)
    sample_dim = 0
    sample_num = 3
    A = create_func(shape, 10, ctx)
    # Make every value non-negative; presumably so the values can serve as
    # sampling weights when ``bias`` is True.
    A = val_like(A, torch.abs(A.val))

    index = torch.tensor(index).to(ctx)

    A_sample = A.sample(sample_dim, sample_num, index, replace, bias)
    A_dense = sparse_matrix_to_dense(A)
    A_sample_to_dense = sparse_matrix_to_dense(A_sample)

    ans_shape = (index.size(0), shape[1])
    # Verify that sampled elements come from the original rows.
    for i, row in enumerate(index.tolist()):
        # Work with plain Python ints so the membership checks are set
        # operations rather than elementwise 0-d tensor comparisons.
        ans_ele = A_dense[row, :].nonzero().reshape(-1).tolist()
        ret_ele = A_sample_to_dense[i, :].nonzero().reshape(-1).tolist()
        extra = set(ret_ele) - set(ans_ele)
        assert not extra, (
            f"row {row}: sampled columns {sorted(extra)} "
            "are not present in the source row"
        )
        if replace:
            # The number of sampled elements in one row should equal
            # 'sample_num' if the source row is non-empty, otherwise 0.
            # Duplicates are allowed here (has_duplicate is only ruled out
            # when replace is False), so count stored entries on
            # A_sample.row instead of the dense view.
            row_count = int((A_sample.row == i).sum().item())
            expected = sample_num if ans_ele else 0
            assert row_count == expected, (
                f"row {row}: expected {expected} sampled entries, "
                f"got {row_count}"
            )
        else:
            assert len(ret_ele) == min(sample_num, len(ans_ele))

    assert A_sample.shape == ans_shape
    if not replace:
        assert not A_sample.has_duplicate()
547
548
549@pytest.mark.parametrize(

Callers

nothing calls this directly

Calls 10

val_like · Function · 0.90
sparse_matrix_to_dense · Function · 0.85
nonzero · Method · 0.80
has_duplicate · Method · 0.80
min · Function · 0.50
ctx · Method · 0.45
to · Method · 0.45
sample · Method · 0.45
size · Method · 0.45
count · Method · 0.45

Tested by

no test coverage detected