MCPcopy
hub / github.com/dmlc/dgl / _test_sample_neighbors_topk

Function _test_sample_neighbors_topk

tests/python/common/sampling/test_sampling.py:974–1055  ·  view source on GitHub ↗
(hypersparse)

Source from the content-addressed store, hash-verified

972
973
974def _test_sample_neighbors_topk(hypersparse):
975 g, hg = _gen_neighbor_topk_test_graph(hypersparse, False)
976
def _test1():
    """Check ``select_topk`` on seed nodes 0 and 1 of the homogeneous graph ``g``.

    Two cases are covered: ``k = -1`` (keep every neighbouring edge) and
    ``k = 2`` (keep the two highest-"weight" edges per seed).
    """
    seeds = [0, 1]

    # k = -1: nothing may be dropped, so the selected edges must be exactly
    # the in-edges of the seeds in the ORIGINAL graph. The expected set is
    # therefore taken from ``g``; deriving it from ``subg`` would compare the
    # result with itself and could not detect missing edges.
    subg = dgl.sampling.select_topk(g, -1, "weight", seeds)
    assert subg.num_nodes() == g.num_nodes()
    u, v = subg.edges()
    u_ans, v_ans = g.in_edges(seeds)
    uv = set(zip(F.asnumpy(u), F.asnumpy(v)))
    uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))
    assert uv == uv_ans

    # k = 2: two edges per seed, four in total.
    subg = dgl.sampling.select_topk(g, 2, "weight", seeds)
    assert subg.num_nodes() == g.num_nodes()
    assert subg.num_edges() == 4
    u, v = subg.edges()
    edge_set = set(zip(F.asnumpy(u), F.asnumpy(v)))
    # The edge IDs recorded on the subgraph must map back to the parent graph.
    assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
    assert edge_set == {(2, 0), (1, 0), (2, 1), (3, 1)}
993
994 _test1()
995
def _test2():  # k > #neighbors
    """Check ``select_topk`` when a seed has fewer than ``k`` neighbours.

    Seeds are nodes 0 and 2; with ``k = 2`` only three edges are expected
    in total, i.e. one seed contributes a single edge.
    """
    seeds = [0, 2]

    # k = -1: every neighbouring edge is kept, so the expected edges are the
    # in-edges of the seeds in the ORIGINAL graph ``g``. Taking them from
    # ``subg`` would compare the result with itself and hide dropped edges.
    subg = dgl.sampling.select_topk(g, -1, "weight", seeds)
    assert subg.num_nodes() == g.num_nodes()
    u, v = subg.edges()
    u_ans, v_ans = g.in_edges(seeds)
    uv = set(zip(F.asnumpy(u), F.asnumpy(v)))
    uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans)))
    assert uv == uv_ans

    # k = 2 with a seed that has fewer than two neighbours: 3 edges overall.
    subg = dgl.sampling.select_topk(g, 2, "weight", seeds)
    assert subg.num_nodes() == g.num_nodes()
    assert subg.num_edges() == 3
    u, v = subg.edges()
    # The edge IDs recorded on the subgraph must map back to the parent graph.
    assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
    edge_set = set(zip(F.asnumpy(u), F.asnumpy(v)))
    assert edge_set == {(2, 0), (1, 0), (0, 2)}
1012
1013 _test2()
1014
1015 def _test3():
1016 subg = dgl.sampling.select_topk(
1017 hg, 2, "weight", {"user": [0, 1], "game": 0}
1018 )
1019 assert len(subg.ntypes) == 3
1020 assert len(subg.etypes) == 4
1021 u, v = subg["follow"].edges()
1022 edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
1023 assert F.array_equal(
1024 hg["follow"].edge_ids(u, v), subg["follow"].edata[dgl.EID]
1025 )
1026 assert edge_set == {(2, 0), (1, 0), (2, 1), (3, 1)}
1027 u, v = subg["play"].edges()
1028 edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
1029 assert F.array_equal(
1030 hg["play"].edge_ids(u, v), subg["play"].edata[dgl.EID]
1031 )

Callers 1

Calls 5

_test1Function · 0.85
_test2Function · 0.85
_test3Function · 0.85
num_edgesMethod · 0.45

Tested by

no test coverage detected