MCPcopy Index your code
hub / github.com/dmlc/dgl / _test2

Function _test2

tests/python/common/sampling/test_sampling.py:615–657  ·  view source on GitHub ↗
(p, replace)

Source from the content-addressed store, hash-verified

_test1(prob, False)  # w/o replacement; presumably uniform only when prob is None
614
def _test2(p, replace):  # fanout > #neighbors
    """Check neighbor sampling around seed nodes 0 and 2 of ``g``.

    Two scenarios are exercised with ``sample_neighbors_fusing_mode[fused]``:

    * ``fanout=-1``: the sampled (src, dst) set must equal the set of all
      in-edges of the seeds. When ``p`` is given, that set is first
      restricted by ``g.edata[p]`` (non-zero entries for ``"prob"``; the
      feature is used directly as a boolean mask for any other key).
    * ``fanout=2``, repeated 10 times because sampling is random: the
      subgraph must have exactly 4 edges with replacement or 3 without.
      Its destinations must be exactly {0, 2}. Every sampled edge must
      exist in ``g`` with a matching ``dgl.EID``. Without replacement no
      pair may repeat. When ``p`` is given, edge (3, 0) must never be drawn
      (presumably because its ``g.edata[p]`` entry is zero -- TODO confirm
      against the graph construction in the enclosing test).

    Relies on ``g``, ``fused`` and ``sample_neighbors_fusing_mode`` from the
    enclosing test scope.

    Args:
        p: Name of the edge feature forwarded as ``prob=`` to the sampler,
            or ``None`` for uniform sampling.
        replace: Whether to sample with replacement.
    """

    def _parent_edges(subg):
        # Return the edge endpoints of ``subg`` as node IDs of ``g``.
        # In fused mode the output is relabelled (its node count is not
        # checked against ``g``), so endpoints are mapped back through
        # dgl.NID.
        u, v = subg.edges()
        if fused:
            u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]
        return u, v

    def _edge_set(u, v):
        # Build an order-insensitive set of (src, dst) pairs.
        return set(zip(F.asnumpy(u), F.asnumpy(v)))

    # fanout=-1: every eligible in-edge of the seeds must be returned.
    subg = sample_neighbors_fusing_mode[fused](
        g, [0, 2], -1, prob=p, replace=replace
    )
    if not fused:
        assert subg.num_nodes() == g.num_nodes()
    u, v = _parent_edges(subg)
    u_ans, v_ans, e_ans = g.in_edges([0, 2], form="all")
    if p is not None:
        emask = F.gather_row(g.edata[p], e_ans)
        if p == "prob":
            # Only edges with a non-zero weight can be sampled; other
            # features are assumed to already be boolean masks.
            emask = emask != 0
        u_ans = F.boolean_mask(u_ans, emask)
        v_ans = F.boolean_mask(v_ans, emask)
    assert _edge_set(u, v) == _edge_set(u_ans, v_ans)

    # fanout=2: the expected edge count does not depend on the draw, so it
    # and the all-ones reference tensor are computed once.
    num_edges = 4 if replace else 3
    all_exist = F.ones((num_edges,), dtype=F.int64)
    for _ in range(10):
        subg = sample_neighbors_fusing_mode[fused](
            g, [0, 2], 2, prob=p, replace=replace
        )
        if not fused:
            assert subg.num_nodes() == g.num_nodes()
        assert subg.num_edges() == num_edges
        u, v = _parent_edges(subg)
        assert set(F.asnumpy(F.unique(v))) == {0, 2}
        assert F.array_equal(
            F.astype(g.has_edges_between(u, v), F.int64), all_exist
        )
        assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
        edge_set = _edge_set(u, v)
        if not replace:
            # Without replacement no (src, dst) pair may appear twice.
            assert len(edge_set) == num_edges
        if p is not None:
            assert (3, 0) not in edge_set
658
_test2(prob, True)  # w/ replacement; uniform only when prob is None
_test2(prob, False)  # w/o replacement; uniform only when prob is None

Callers 5

_test_sample_neighbors · Function · 0.85
_test_sample_labors · Function · 0.85

Calls 9

asnumpy · Method · 0.80
num_nodes · Method · 0.45
edges · Method · 0.45
in_edges · Method · 0.45
num_edges · Method · 0.45
astype · Method · 0.45
has_edges_between · Method · 0.45
edge_ids · Method · 0.45
out_edges · Method · 0.45

Tested by

no test coverage detected