MCPcopy
hub / github.com/dmlc/dgl / test_sample_columnwise

Function test_sample_columnwise

tests/python/pytorch/sparse/test_sparse_matrix.py:555–588  ·  view source on GitHub ↗
(create_func, index, replace, bias)

Source from the content-addressed store, hash-verified

553@pytest.mark.parametrize("replace", [False, True])
554@pytest.mark.parametrize("bias", [False, True])
def test_sample_columnwise(create_func, index, replace, bias):
    """Check column-wise ``SparseMatrix.sample`` (``sample_dim == 1``).

    The i-th column of the sampled matrix is compared against column
    ``index[i]`` of the input. The test verifies that:

    * the result has shape ``(shape[0], len(index))``;
    * every non-zero row of sampled column i is a non-zero row of the
      corresponding source column;
    * with replacement, a non-empty source column yields exactly
      ``sample_num`` stored entries and an empty one yields none;
    * without replacement, column i has ``min(sample_num, nnz)`` non-zero
      rows and the result contains no duplicate coordinates.

    Parameters come from pytest parametrization. ``create_func(shape, 10,
    ctx)`` builds the input matrix (the ``10`` is presumably the number of
    non-zeros; confirm against the helper). ``index`` lists the columns to
    sample. ``replace`` and ``bias`` are forwarded unchanged to ``sample``.
    """
    ctx = F.ctx()
    shape = (5, 5)
    sample_dim = 1
    sample_num = 3
    A = create_func(shape, 10, ctx)
    # Make every value non-negative. This is presumably required because
    # ``bias=True`` uses the values as sampling weights.
    A = val_like(A, torch.abs(A.val))

    index = torch.tensor(index).to(ctx)

    A_sample = A.sample(sample_dim, sample_num, index, replace, bias)
    A_dense = sparse_matrix_to_dense(A)
    A_sample_to_dense = sparse_matrix_to_dense(A_sample)

    # Check the shape first: a wrong column count would otherwise surface as
    # an IndexError from ``A_sample_to_dense[:, i]`` inside the loop.
    assert A_sample.shape == (shape[0], index.size(0))

    # Verify that sampled elements come from the original columns. Working on
    # Python ints (via ``tolist``) avoids per-element tensor comparisons and
    # the host/device syncs they trigger on GPU.
    for i, col in enumerate(index.tolist()):
        ans_ele = set(A_dense[:, col].nonzero().reshape(-1).tolist())
        ret_ele = set(A_sample_to_dense[:, i].nonzero().reshape(-1).tolist())
        assert ret_ele <= ans_ele
        if replace:
            # With replacement, a non-empty column must contribute exactly
            # 'sample_num' stored entries (duplicates allowed). An empty
            # column contributes none. Count stored entries in the COO column
            # array, because duplicates collapse in the dense view.
            num_sampled = (A_sample.col == i).sum().item()
            assert num_sampled == (sample_num if ans_ele else 0)
        else:
            assert len(ret_ele) == min(sample_num, len(ans_ele))

    if not replace:
        assert not A_sample.has_duplicate()
589
590
591def test_print():

Callers

nothing calls this directly

Calls 10

val_like — Function · 0.90
sparse_matrix_to_dense — Function · 0.85
nonzero — Method · 0.80
has_duplicate — Method · 0.80
min — Function · 0.50
ctx — Method · 0.45
to — Method · 0.45
sample — Method · 0.45
size — Method · 0.45
count — Method · 0.45

Tested by

no test coverage detected