(p, replace)
| 613 | _test1(prob, False) # w/o replacement, uniform |
| 614 | |
def _test2(p, replace):  # fanout > #neighbors
    """Check neighbor sampling of seeds [0, 2] with large/unbounded fanout.

    Relies on ``g``, ``fused``, ``sample_neighbors_fusing_mode``, ``dgl`` and
    ``F`` from the enclosing test scope / module.

    Args:
        p: Name of the edge feature passed as ``prob``, or ``None`` for
            uniform sampling. The feature named ``"prob"`` is treated as
            weights (an edge is eligible iff its weight is non-zero); any
            other feature name is used directly as a boolean edge mask.
        replace: Whether neighbors are sampled with replacement.
    """

    def _sampled_edges(subg):
        # With ``fused`` the sampled graph is indexed by its own local node
        # IDs; map them back to ``g`` through ``dgl.NID`` before comparing.
        u, v = subg.edges()
        if fused:
            u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]
        return u, v

    def _pair_set(u, v):
        # Set of (src, dst) pairs for order-insensitive comparison.
        return set(zip(F.asnumpy(u), F.asnumpy(v)))

    # fanout = -1: every eligible in-edge of the seeds must be returned.
    subg = sample_neighbors_fusing_mode[fused](
        g, [0, 2], -1, prob=p, replace=replace
    )
    if not fused:
        assert subg.num_nodes() == g.num_nodes()
    u, v = _sampled_edges(subg)
    u_ans, v_ans, e_ans = g.in_edges([0, 2], form="all")
    if p is not None:
        emask = F.gather_row(g.edata[p], e_ans)
        if p == "prob":
            # Weights: an edge is eligible iff its probability is non-zero.
            emask = emask != 0
        u_ans = F.boolean_mask(u_ans, emask)
        v_ans = F.boolean_mask(v_ans, emask)
    assert _pair_set(u, v) == _pair_set(u_ans, v_ans)

    # fanout = 2, repeated because sampling is random.
    for _ in range(10):
        subg = sample_neighbors_fusing_mode[fused](
            g, [0, 2], 2, prob=p, replace=replace
        )
        if not fused:
            assert subg.num_nodes() == g.num_nodes()
        # NOTE(review): the expected counts come from the fixture ``g``
        # (not visible here) -- presumably one seed has fewer than 2
        # eligible in-neighbors, so only 3 edges exist without replacement.
        num_edges = 4 if replace else 3
        assert subg.num_edges() == num_edges
        u, v = _sampled_edges(subg)
        assert set(F.asnumpy(F.unique(v))) == {0, 2}
        # Every sampled pair must be an existing edge of ``g`` ...
        assert F.array_equal(
            F.astype(g.has_edges_between(u, v), F.int64),
            F.ones((num_edges,), dtype=F.int64),
        )
        # ... and the recorded parent edge IDs must agree with ``g``.
        assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
        edge_set = _pair_set(u, v)
        if not replace:
            # Without replacement no (src, dst) pair may repeat.
            assert len(edge_set) == num_edges
        if p is not None:
            # NOTE(review): presumably edge (3, 0) is ineligible under every
            # ``p`` the caller passes (zero weight / masked out) -- confirm.
            assert (3, 0) not in edge_set
| 659 | _test2(prob, True) # w/ replacement, uniform |
| 660 | _test2(prob, False) # w/o replacement, uniform |
no test coverage detected