MCPcopy
hub / github.com/dmlc/dgl / _test_sample_labors

Function _test_sample_labors

tests/python/common/sampling/test_sampling.py:712–823  ·  view source on GitHub ↗
(hypersparse, prob)

Source from the content-addressed store, hash-verified

710
711
712def _test_sample_labors(hypersparse, prob):
713 g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False)
714
715 # test with seed nodes [0, 1]
716 def _test1(p):
717 subg = dgl.sampling.sample_labors(g, [0, 1], -1, prob=p)[0]
718 assert subg.num_nodes() == g.num_nodes()
719 u, v = subg.edges()
720 u_ans, v_ans, e_ans = g.in_edges([0, 1], form="all")
721 if p is not None:
722 emask = F.gather_row(g.edata[p], e_ans)
723 if p == "prob":
724 emask = emask != 0
725 u_ans = F.boolean_mask(u_ans, emask)
726 v_ans = F.boolean_mask(v_ans, emask)
727 uv = set(zip(F.asnumpy(u), F.asnumpy(v)))
728 uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))
729 assert uv == uv_ans
730
731 for i in range(10):
732 subg = dgl.sampling.sample_labors(g, [0, 1], 2, prob=p)[0]
733 assert subg.num_nodes() == g.num_nodes()
734 assert subg.num_edges() >= 0
735 u, v = subg.edges()
736 assert set(F.asnumpy(F.unique(v))).issubset({0, 1})
737 assert F.array_equal(
738 F.astype(g.has_edges_between(u, v), F.int64),
739 F.ones((subg.num_edges(),), dtype=F.int64),
740 )
741 assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
742 edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
743 # check no duplication
744 assert len(edge_set) == subg.num_edges()
745 if p is not None:
746 assert not (3, 0) in edge_set
747 assert not (3, 1) in edge_set
748
749 _test1(prob)
750
751 # test with seed nodes [0, 2]
752 def _test2(p):
753 subg = dgl.sampling.sample_labors(g, [0, 2], -1, prob=p)[0]
754 assert subg.num_nodes() == g.num_nodes()
755 u, v = subg.edges()
756 u_ans, v_ans, e_ans = g.in_edges([0, 2], form="all")
757 if p is not None:
758 emask = F.gather_row(g.edata[p], e_ans)
759 if p == "prob":
760 emask = emask != 0
761 u_ans = F.boolean_mask(u_ans, emask)
762 v_ans = F.boolean_mask(v_ans, emask)
763 uv = set(zip(F.asnumpy(u), F.asnumpy(v)))
764 uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))
765 assert uv == uv_ans
766
767 for i in range(10):
768 subg = dgl.sampling.sample_labors(g, [0, 2], 2, prob=p)[0]
769 assert subg.num_nodes() == g.num_nodes()

Callers 2

test_sample_labors_prob · Function · 0.85

Calls 6

_test1 · Function · 0.85
_test2 · Function · 0.85
_test3 · Function · 0.85
num_nodes · Method · 0.45
num_edges · Method · 0.45

Tested by

no test coverage detected