| 2737 | |
| 2738 | @parametrize_idtype |
| 2739 | def test_module_add_metapaths(idtype): |
| 2740 | g = dgl.heterograph( |
| 2741 | { |
| 2742 | ("person", "author", "paper"): ([0, 0, 1], [1, 2, 2]), |
| 2743 | ("paper", "accepted", "venue"): ([1], [0]), |
| 2744 | ("paper", "rejected", "venue"): ([2], [1]), |
| 2745 | }, |
| 2746 | idtype=idtype, |
| 2747 | device=F.ctx(), |
| 2748 | ) |
| 2749 | g.nodes["venue"].data["h"] = F.randn((g.num_nodes("venue"), 2)) |
| 2750 | g.edges["author"].data["h"] = F.randn((g.num_edges("author"), 3)) |
| 2751 | |
| 2752 | # Case1: keep_orig_edges is True |
| 2753 | metapaths = { |
| 2754 | "accepted": [ |
| 2755 | ("person", "author", "paper"), |
| 2756 | ("paper", "accepted", "venue"), |
| 2757 | ], |
| 2758 | "rejected": [ |
| 2759 | ("person", "author", "paper"), |
| 2760 | ("paper", "rejected", "venue"), |
| 2761 | ], |
| 2762 | } |
| 2763 | transform = dgl.AddMetaPaths(metapaths) |
| 2764 | new_g = transform(g) |
| 2765 | assert new_g.device == g.device |
| 2766 | assert new_g.idtype == g.idtype |
| 2767 | assert new_g.ntypes == g.ntypes |
| 2768 | assert set(new_g.canonical_etypes) == { |
| 2769 | ("person", "author", "paper"), |
| 2770 | ("paper", "accepted", "venue"), |
| 2771 | ("paper", "rejected", "venue"), |
| 2772 | ("person", "accepted", "venue"), |
| 2773 | ("person", "rejected", "venue"), |
| 2774 | } |
| 2775 | for nty in new_g.ntypes: |
| 2776 | assert new_g.num_nodes(nty) == g.num_nodes(nty) |
| 2777 | for ety in g.canonical_etypes: |
| 2778 | assert new_g.num_edges(ety) == g.num_edges(ety) |
| 2779 | assert F.allclose( |
| 2780 | g.nodes["venue"].data["h"], new_g.nodes["venue"].data["h"] |
| 2781 | ) |
| 2782 | assert F.allclose( |
| 2783 | g.edges["author"].data["h"], new_g.edges["author"].data["h"] |
| 2784 | ) |
| 2785 | |
| 2786 | src, dst = new_g.edges(etype=("person", "accepted", "venue")) |
| 2787 | eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) |
| 2788 | assert eset == {(0, 0)} |
| 2789 | |
| 2790 | src, dst = new_g.edges(etype=("person", "rejected", "venue")) |
| 2791 | eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) |
| 2792 | assert eset == {(0, 1), (1, 1)} |
| 2793 | |
| 2794 | # Case2: keep_orig_edges is False |
| 2795 | transform = dgl.AddMetaPaths(metapaths, keep_orig_edges=False) |
| 2796 | new_g = transform(g) |