(p, replace)
| 564 | g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False) |
| 565 | |
def _test1(p, replace):
    """Check neighbor sampling around seed nodes 0 and 1 of ``g``.

    Two scenarios are exercised with the sampler selected by ``fused``
    (``sample_neighbors_fusing_mode[fused]``):

    1. ``fanout=-1``: the sampled (src, dst) pairs must equal the in-edges
       of the seeds in ``g``, filtered by ``p`` when one is given.
    2. ``fanout=2``, repeated ``num_trials`` times. Each sample must:
       - hold exactly ``len(seeds) * fanout`` edges that exist in ``g``;
       - have destinations that are exactly the seeds;
       - carry the parent edge IDs in ``edata[dgl.EID]``.
       Without replacement there must be no duplicate edges. When ``p`` is
       set, edges ``(3, 0)`` and ``(3, 1)`` must never be drawn.

    Parameters
    ----------
    p : str or None
        Name of an edge feature of ``g`` passed as ``prob=`` to the sampler,
        or None for unweighted sampling. When ``p == "prob"`` the feature is
        compared against 0 to build the expected-edge mask. Any other feature
        is passed directly to ``F.boolean_mask``, i.e. used as a mask.
    replace : bool
        Whether to sample with replacement.
    """
    seeds = [0, 1]
    fanout = 2
    num_trials = 10
    # Each seed contributes exactly ``fanout`` edges per sample.
    num_expected = len(seeds) * fanout

    def _sample(k):
        # Run the sampler under test on ``seeds`` with fanout ``k``.
        subg = sample_neighbors_fusing_mode[fused](
            g, seeds, k, prob=p, replace=replace
        )
        if not fused:
            # The non-fused sampler is expected to keep every node of ``g``.
            assert subg.num_nodes() == g.num_nodes()
        return subg

    def _edges_in_parent_ids(subg):
        # Return the endpoints of ``subg``'s edges expressed as node IDs of ``g``.
        u, v = subg.edges()
        if fused:
            # Fused-mode subgraphs appear to use their own node numbering;
            # the original IDs are stored in srcdata/dstdata[dgl.NID].
            u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]
        return u, v

    # Scenario 1: fanout=-1 should return all (eligible) in-edges of the seeds.
    subg = _sample(-1)
    u, v = _edges_in_parent_ids(subg)
    u_ans, v_ans, e_ans = g.in_edges(seeds, form="all")
    if p is not None:
        emask = F.gather_row(g.edata[p], e_ans)
        if p == "prob":
            # Edges with zero weight are not expected in the sample.
            emask = emask != 0
        u_ans = F.boolean_mask(u_ans, emask)
        v_ans = F.boolean_mask(v_ans, emask)
    # Compare as (src, dst) sets so edge order does not matter.
    uv = set(zip(F.asnumpy(u), F.asnumpy(v)))
    uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))
    assert uv == uv_ans

    # Scenario 2: repeated fixed-fanout sampling.
    for _ in range(num_trials):
        subg = _sample(fanout)
        assert subg.num_edges() == num_expected
        u, v = _edges_in_parent_ids(subg)

        assert set(F.asnumpy(F.unique(v))) == set(seeds)
        # Every sampled (u, v) pair must be an existing edge of ``g`` ...
        assert F.array_equal(
            F.astype(g.has_edges_between(u, v), F.int64),
            F.ones((num_expected,), dtype=F.int64),
        )
        # ... and must carry the matching parent edge ID.
        assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
        edge_set = set(zip(F.asnumpy(u), F.asnumpy(v)))
        if not replace:
            # Without replacement no edge may be drawn twice.
            assert len(edge_set) == num_expected
        if p is not None:
            # NOTE(review): presumably the test graph gives these edges zero
            # weight / a False mask -- confirm in
            # _gen_neighbor_sampling_test_graph.
            assert (3, 0) not in edge_set
            assert (3, 1) not in edge_set
| 612 | _test1(prob, True) # w/ replacement, uniform |
| 613 | _test1(prob, False) # w/o replacement, uniform |
no test coverage detected