MCPcopy
hub / github.com/dmlc/dgl / test_module_add_metapaths

Function test_module_add_metapaths

tests/python/common/transforms/test_transform.py:2739–2813  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

2737
2738@parametrize_idtype
2739def test_module_add_metapaths(idtype):
2740 g = dgl.heterograph(
2741 {
2742 ("person", "author", "paper"): ([0, 0, 1], [1, 2, 2]),
2743 ("paper", "accepted", "venue"): ([1], [0]),
2744 ("paper", "rejected", "venue"): ([2], [1]),
2745 },
2746 idtype=idtype,
2747 device=F.ctx(),
2748 )
2749 g.nodes["venue"].data["h"] = F.randn((g.num_nodes("venue"), 2))
2750 g.edges["author"].data["h"] = F.randn((g.num_edges("author"), 3))
2751
2752 # Case1: keep_orig_edges is True
2753 metapaths = {
2754 "accepted": [
2755 ("person", "author", "paper"),
2756 ("paper", "accepted", "venue"),
2757 ],
2758 "rejected": [
2759 ("person", "author", "paper"),
2760 ("paper", "rejected", "venue"),
2761 ],
2762 }
2763 transform = dgl.AddMetaPaths(metapaths)
2764 new_g = transform(g)
2765 assert new_g.device == g.device
2766 assert new_g.idtype == g.idtype
2767 assert new_g.ntypes == g.ntypes
2768 assert set(new_g.canonical_etypes) == {
2769 ("person", "author", "paper"),
2770 ("paper", "accepted", "venue"),
2771 ("paper", "rejected", "venue"),
2772 ("person", "accepted", "venue"),
2773 ("person", "rejected", "venue"),
2774 }
2775 for nty in new_g.ntypes:
2776 assert new_g.num_nodes(nty) == g.num_nodes(nty)
2777 for ety in g.canonical_etypes:
2778 assert new_g.num_edges(ety) == g.num_edges(ety)
2779 assert F.allclose(
2780 g.nodes["venue"].data["h"], new_g.nodes["venue"].data["h"]
2781 )
2782 assert F.allclose(
2783 g.edges["author"].data["h"], new_g.edges["author"].data["h"]
2784 )
2785
2786 src, dst = new_g.edges(etype=("person", "accepted", "venue"))
2787 eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
2788 assert eset == {(0, 0)}
2789
2790 src, dst = new_g.edges(etype=("person", "rejected", "venue"))
2791 eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
2792 assert eset == {(0, 1), (1, 1)}
2793
2794 # Case2: keep_orig_edges is False
2795 transform = dgl.AddMetaPaths(metapaths, keep_orig_edges=False)
2796 new_g = transform(g)

Callers

Nothing calls this directly (it is a pytest test function, invoked by the test runner).

Calls 6

transform (Function) · 0.85
asnumpy (Method) · 0.80
ctx (Method) · 0.45
num_nodes (Method) · 0.45
num_edges (Method) · 0.45
edges (Method) · 0.45

Tested by

no test coverage detected