MCPcopy
hub / github.com/dmlc/dgl / test_view1

Function test_view1

tests/python/common/test_heterograph.py:879–1008  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

877
878@parametrize_idtype
879def test_view1(idtype):
880 # test relation view
881 HG = create_test_heterograph(idtype)
882 ntypes = ["user", "game", "developer"]
883 canonical_etypes = [
884 ("user", "follows", "user"),
885 ("user", "plays", "game"),
886 ("user", "wishes", "game"),
887 ("developer", "develops", "game"),
888 ]
889 etypes = ["follows", "plays", "wishes", "develops"]
890
891 def _test_query():
892 for etype in etypes:
893 utype, _, vtype = HG.to_canonical_etype(etype)
894 g = HG[etype]
895 srcs, dsts = edges[etype]
896 for src, dst in zip(srcs, dsts):
897 assert g.has_edges_between(src, dst)
898 assert F.asnumpy(g.has_edges_between(srcs, dsts)).all()
899
900 srcs, dsts = negative_edges[etype]
901 for src, dst in zip(srcs, dsts):
902 assert not g.has_edges_between(src, dst)
903 assert not F.asnumpy(g.has_edges_between(srcs, dsts)).any()
904
905 srcs, dsts = edges[etype]
906 n_edges = len(srcs)
907
908 # predecessors & in_edges & in_degree
909 pred = [s for s, d in zip(srcs, dsts) if d == 0]
910 assert set(F.asnumpy(g.predecessors(0)).tolist()) == set(pred)
911 u, v = g.in_edges([0])
912 assert F.asnumpy(v).tolist() == [0] * len(pred)
913 assert set(F.asnumpy(u).tolist()) == set(pred)
914 assert g.in_degrees(0) == len(pred)
915
916 # successors & out_edges & out_degree
917 succ = [d for s, d in zip(srcs, dsts) if s == 0]
918 assert set(F.asnumpy(g.successors(0)).tolist()) == set(succ)
919 u, v = g.out_edges([0])
920 assert F.asnumpy(u).tolist() == [0] * len(succ)
921 assert set(F.asnumpy(v).tolist()) == set(succ)
922 assert g.out_degrees(0) == len(succ)
923
924 # edge_ids
925 for i, (src, dst) in enumerate(zip(srcs, dsts)):
926 assert g.edge_ids(src, dst, etype=etype) == i
927 _, _, eid = g.edge_ids(src, dst, etype=etype, return_uv=True)
928 assert eid == i
929 assert F.asnumpy(g.edge_ids(srcs, dsts)).tolist() == list(
930 range(n_edges)
931 )
932 u, v, e = g.edge_ids(srcs, dsts, return_uv=True)
933 u, v, e = F.asnumpy(u), F.asnumpy(v), F.asnumpy(e)
934 assert u[e].tolist() == srcs
935 assert v[e].tolist() == dsts
936

Callers

nothing calls this directly

Calls 5

_test_query — Function · 0.85
create_test_heterograph — Function · 0.70
num_nodes — Method · 0.45
nodes — Method · 0.45
edges — Method · 0.45

Tested by

no test coverage detected