(hypersparse)
| 972 | |
| 973 | |
| 974 | def _test_sample_neighbors_topk(hypersparse): |
| 975 | g, hg = _gen_neighbor_topk_test_graph(hypersparse, False) |
| 976 | |
| 977 | def _test1(): |
| 978 | subg = dgl.sampling.select_topk(g, -1, "weight", [0, 1]) |
| 979 | assert subg.num_nodes() == g.num_nodes() |
| 980 | u, v = subg.edges() |
| 981 | u_ans, v_ans = subg.in_edges([0, 1]) |
| 982 | uv = set(zip(F.asnumpy(u), F.asnumpy(v))) |
| 983 | uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans))) |
| 984 | assert uv == uv_ans |
| 985 | |
| 986 | subg = dgl.sampling.select_topk(g, 2, "weight", [0, 1]) |
| 987 | assert subg.num_nodes() == g.num_nodes() |
| 988 | assert subg.num_edges() == 4 |
| 989 | u, v = subg.edges() |
| 990 | edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v)))) |
| 991 | assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID]) |
| 992 | assert edge_set == {(2, 0), (1, 0), (2, 1), (3, 1)} |
| 993 | |
| 994 | _test1() |
| 995 | |
| 996 | def _test2(): # k > #neighbors |
| 997 | subg = dgl.sampling.select_topk(g, -1, "weight", [0, 2]) |
| 998 | assert subg.num_nodes() == g.num_nodes() |
| 999 | u, v = subg.edges() |
| 1000 | u_ans, v_ans = subg.in_edges([0, 2]) |
| 1001 | uv = set(zip(F.asnumpy(u), F.asnumpy(v))) |
| 1002 | uv_ans = set(zip(F.asnumpy(u_ans), F.asnumpy(v_ans))) |
| 1003 | assert uv == uv_ans |
| 1004 | |
| 1005 | subg = dgl.sampling.select_topk(g, 2, "weight", [0, 2]) |
| 1006 | assert subg.num_nodes() == g.num_nodes() |
| 1007 | assert subg.num_edges() == 3 |
| 1008 | u, v = subg.edges() |
| 1009 | assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID]) |
| 1010 | edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v)))) |
| 1011 | assert edge_set == {(2, 0), (1, 0), (0, 2)} |
| 1012 | |
| 1013 | _test2() |
| 1014 | |
| 1015 | def _test3(): |
| 1016 | subg = dgl.sampling.select_topk( |
| 1017 | hg, 2, "weight", {"user": [0, 1], "game": 0} |
| 1018 | ) |
| 1019 | assert len(subg.ntypes) == 3 |
| 1020 | assert len(subg.etypes) == 4 |
| 1021 | u, v = subg["follow"].edges() |
| 1022 | edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v)))) |
| 1023 | assert F.array_equal( |
| 1024 | hg["follow"].edge_ids(u, v), subg["follow"].edata[dgl.EID] |
| 1025 | ) |
| 1026 | assert edge_set == {(2, 0), (1, 0), (2, 1), (3, 1)} |
| 1027 | u, v = subg["play"].edges() |
| 1028 | edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v)))) |
| 1029 | assert F.array_equal( |
| 1030 | hg["play"].edge_ids(u, v), subg["play"].edata[dgl.EID] |
| 1031 | ) |
no test coverage detected