| 2857 | ) |
@parametrize_idtype
def test_module_ppr(idtype):
    """Check the ``dgl.PPR`` transform on a small 6-node directed graph.

    The same transform instance is applied twice:

    1. Without input edge weights. The output must keep the input's idtype,
       device, node count and node feature ``"h"``, must carry an edge
       feature ``"w"``, and its edge set must equal a fixed expected set.
    2. After attaching a prior edge feature ``"w"`` to the input graph. Only
       the resulting edge set is checked against a second expected set.
    """

    def edge_pairs(graph):
        # Materialise (src, dst) pairs into a set so the comparison does not
        # depend on the order in which the transform emits edges.
        u, v = graph.edges()
        return set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))

    g = dgl.graph(
        ([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, device=F.ctx()
    )
    g.ndata["h"] = F.randn((6, 2))
    ppr = dgl.PPR(avg_degree=2)

    # Case 1: input graph has no edge weights.
    diffused = ppr(g)
    assert diffused.idtype == g.idtype
    assert diffused.device == g.device
    assert diffused.num_nodes() == g.num_nodes()
    expected_unweighted = {
        (0, 0), (0, 2), (0, 4),
        (1, 1), (1, 3), (1, 5),
        (2, 2), (2, 3), (2, 4),
        (3, 3), (3, 5),
        (4, 3), (4, 4), (4, 5),
        (5, 5),
    }
    assert edge_pairs(diffused) == expected_unweighted
    assert F.allclose(g.ndata["h"], diffused.ndata["h"])
    assert "w" in diffused.edata

    # Case 2: input graph carries prior edge weights under "w".
    g.edata["w"] = F.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
    diffused = ppr(g)
    expected_weighted = {
        (0, 0),
        (1, 1), (1, 3),
        (2, 2), (2, 3), (2, 4),
        (3, 3), (3, 5),
        (4, 3), (4, 4), (4, 5),
        (5, 5),
    }
    assert edge_pairs(diffused) == expected_weighted
| 2910 | |
| 2911 | |
| 2912 | @unittest.skipIf( |