| 2914 | ) |
@parametrize_idtype
def test_module_heat_kernel(idtype):
    """Exercise ``dgl.HeatKernel`` on a directed graph and a weighted graph.

    A single transform instance built with ``avg_degree=1`` is applied to both
    graphs; the second application checks which edges survive sparsification.
    """
    ctx = F.ctx()
    heat = dgl.HeatKernel(avg_degree=1)

    # Case1: directed graph. The transformed graph must keep the id type,
    # device, node count and node features of the input, and must expose an
    # edge feature named "w".
    graph = dgl.graph(
        ([0, 1, 2, 3, 4], [2, 3, 4, 5, 3]), idtype=idtype, device=ctx
    )
    graph.ndata["h"] = F.randn((6, 2))
    out = heat(graph)
    assert out.idtype == graph.idtype
    assert out.device == graph.device
    assert out.num_nodes() == graph.num_nodes()
    assert F.allclose(graph.ndata["h"], out.ndata["h"])
    assert "w" in out.edata

    # Case2: weighted graph made of two disconnected reciprocal pairs
    # (0<->1 and 2<->3). With avg_degree=1, only the four self-loops are
    # expected to remain after the transform.
    graph = dgl.graph(([0, 1, 2, 3], [1, 0, 3, 2]), idtype=idtype, device=ctx)
    graph.edata["w"] = F.tensor([0.1, 0.2, 0.3, 0.4])
    out = heat(graph)
    u, v = out.edges()
    kept = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
    assert kept == {(0, 0), (1, 1), (2, 2), (3, 3)}
| 2938 | |
| 2939 | @unittest.skipIf( |