MCPcopy
hub / github.com/dmlc/dgl / _test_sample_neighbors

Function _test_sample_neighbors

tests/python/common/sampling/test_sampling.py:563–709  ·  view source on GitHub ↗
(hypersparse, prob, fused)

Source from the content-addressed store, hash-verified

561
562
563def _test_sample_neighbors(hypersparse, prob, fused):
564 g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False)
565
def _test1(p, replace):
    """Check neighbor sampling from seed nodes 0 and 1 of ``g``.

    First samples with fanout -1 and checks that the sampled edge set
    equals the seeds' full in-edge set. When ``p`` is given, that in-edge
    set is first filtered by a mask derived from ``g.edata[p]``. Then,
    ten times, samples with fanout 2 and checks:

    - the edge count;
    - the destination set;
    - that every sampled edge exists in ``g``;
    - the recorded parent edge IDs;
    - uniqueness, when sampling without replacement;
    - that edges (3, 0) and (3, 1) are never sampled, when ``p`` is given.

    Parameters
    ----------
    p : str or None
        Name of the edge feature passed to the sampler as ``prob``;
        ``None`` means no edge feature is used.
    replace : bool
        Whether to sample with replacement.
    """

    def _sample(fanout):
        # Sample neighbors of seeds 0 and 1; return the subgraph plus its
        # edge endpoints expressed in the parent graph's node IDs.
        subg = sample_neighbors_fusing_mode[fused](
            g, [0, 1], fanout, prob=p, replace=replace
        )
        if not fused:
            # The non-fused sampler is expected to keep all nodes of g.
            assert subg.num_nodes() == g.num_nodes()
        u, v = subg.edges()
        if fused:
            # In fused mode the subgraph has its own node IDs; map them
            # back to parent-graph IDs through dgl.NID.
            u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]
        return subg, u, v

    # fanout == -1: every eligible in-edge of the seeds must be returned.
    _, u, v = _sample(-1)
    u_ans, v_ans, e_ans = g.in_edges([0, 1], form="all")
    if p is not None:
        emask = F.gather_row(g.edata[p], e_ans)
        if p == "prob":
            # Only edges with non-zero probability are expected to appear.
            emask = emask != 0
        # NOTE(review): for p != "prob" the feature itself is used as the
        # mask, presumably a boolean edge mask -- confirm against the
        # graph generator.
        u_ans = F.boolean_mask(u_ans, emask)
        v_ans = F.boolean_mask(v_ans, emask)
    uv = set(zip(F.asnumpy(u), F.asnumpy(v)))
    uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))
    assert uv == uv_ans

    # fanout == 2: sampling is random, so repeat the checks several times.
    for _ in range(10):
        subg, u, v = _sample(2)
        # Two seeds with fanout 2 each yield exactly four edges.
        assert subg.num_edges() == 4
        # Both seeds appear as destinations.
        assert set(F.asnumpy(F.unique(v))) == {0, 1}
        # Every sampled edge exists in the parent graph ...
        assert F.array_equal(
            F.astype(g.has_edges_between(u, v), F.int64),
            F.ones((4,), dtype=F.int64),
        )
        # ... and the subgraph records its parent edge IDs.
        assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
        edge_set = set(zip(F.asnumpy(u), F.asnumpy(v)))
        if not replace:
            # check no duplication
            assert len(edge_set) == 4
        if p is not None:
            assert (3, 0) not in edge_set
            assert (3, 1) not in edge_set
612 _test1(prob, True) # w/ replacement, uniform
613 _test1(prob, False) # w/o replacement, uniform
614
615 def _test2(p, replace): # fanout > #neighbors
616 subg = sample_neighbors_fusing_mode[fused](
617 g, [0, 2], -1, prob=p, replace=replace
618 )
619 if not fused:
620 assert subg.num_nodes() == g.num_nodes()

Callers 3

Calls 5

_test1 (Function) · 0.85
_test2 (Function) · 0.85
_test3 (Function) · 0.85
num_edges (Method) · 0.45

Tested by

no test coverage detected