| 710 | |
| 711 | |
| 712 | def _test_sample_labors(hypersparse, prob): |
| 713 | g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False) |
| 714 | |
| 715 | # test with seed nodes [0, 1] |
| 716 | def _test1(p): |
| 717 | subg = dgl.sampling.sample_labors(g, [0, 1], -1, prob=p)[0] |
| 718 | assert subg.num_nodes() == g.num_nodes() |
| 719 | u, v = subg.edges() |
| 720 | u_ans, v_ans, e_ans = g.in_edges([0, 1], form="all") |
| 721 | if p is not None: |
| 722 | emask = F.gather_row(g.edata[p], e_ans) |
| 723 | if p == "prob": |
| 724 | emask = emask != 0 |
| 725 | u_ans = F.boolean_mask(u_ans, emask) |
| 726 | v_ans = F.boolean_mask(v_ans, emask) |
| 727 | uv = set(zip(F.asnumpy(u), F.asnumpy(v))) |
| 728 | uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans))) |
| 729 | assert uv == uv_ans |
| 730 | |
| 731 | for i in range(10): |
| 732 | subg = dgl.sampling.sample_labors(g, [0, 1], 2, prob=p)[0] |
| 733 | assert subg.num_nodes() == g.num_nodes() |
| 734 | assert subg.num_edges() >= 0 |
| 735 | u, v = subg.edges() |
| 736 | assert set(F.asnumpy(F.unique(v))).issubset({0, 1}) |
| 737 | assert F.array_equal( |
| 738 | F.astype(g.has_edges_between(u, v), F.int64), |
| 739 | F.ones((subg.num_edges(),), dtype=F.int64), |
| 740 | ) |
| 741 | assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID]) |
| 742 | edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v)))) |
| 743 | # check no duplication |
| 744 | assert len(edge_set) == subg.num_edges() |
| 745 | if p is not None: |
| 746 | assert not (3, 0) in edge_set |
| 747 | assert not (3, 1) in edge_set |
| 748 | |
| 749 | _test1(prob) |
| 750 | |
| 751 | # test with seed nodes [0, 2] |
| 752 | def _test2(p): |
| 753 | subg = dgl.sampling.sample_labors(g, [0, 2], -1, prob=p)[0] |
| 754 | assert subg.num_nodes() == g.num_nodes() |
| 755 | u, v = subg.edges() |
| 756 | u_ans, v_ans, e_ans = g.in_edges([0, 2], form="all") |
| 757 | if p is not None: |
| 758 | emask = F.gather_row(g.edata[p], e_ans) |
| 759 | if p == "prob": |
| 760 | emask = emask != 0 |
| 761 | u_ans = F.boolean_mask(u_ans, emask) |
| 762 | v_ans = F.boolean_mask(v_ans, emask) |
| 763 | uv = set(zip(F.asnumpy(u), F.asnumpy(v))) |
| 764 | uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans))) |
| 765 | assert uv == uv_ans |
| 766 | |
| 767 | for i in range(10): |
| 768 | subg = dgl.sampling.sample_labors(g, [0, 2], 2, prob=p)[0] |
| 769 | assert subg.num_nodes() == g.num_nodes() |