| 2829 | |
@parametrize_idtype
def test_module_gcnnorm(idtype):
    """Check ``dgl.GCNNorm`` on a heterograph with three relations.

    Two relations keep the same source and destination node type
    (``r1`` on ``A`` and ``r3`` on ``B``); ``r2`` goes from ``A`` to ``B``.
    Only ``r3`` is given an initial edge weight ``"w"``. After the
    transform, the test asserts that ``r2`` carries no ``"w"`` feature and
    that ``r1`` / ``r3`` carry the expected normalized weights.
    """
    relations = {
        ("A", "r1", "A"): ([0, 1, 2], [0, 0, 1]),
        ("A", "r2", "B"): ([0, 0], [1, 1]),
        ("B", "r3", "B"): ([0, 1, 2], [0, 0, 1]),
    }
    graph = dgl.heterograph(relations, idtype=idtype, device=F.ctx())
    graph.edges["r3"].data["w"] = F.tensor([0.1, 0.2, 0.3])

    normalized = dgl.GCNNorm()(graph)

    # The A->B relation is expected to come out without a "w" feature.
    assert "w" not in normalized.edges[("A", "r2", "B")].data

    # NOTE(review): the expected values below are consistent with
    # w_ij / sqrt(d_i * d_j), where d is the (weighted) in-degree and the
    # result is 0 when a degree is 0 -- confirm against the dgl.GCNNorm docs.
    unweighted_expected = F.tensor([1.0 / 2, 1.0 / math.sqrt(2), 0.0])
    assert F.allclose(
        normalized.edges[("A", "r1", "A")].data["w"],
        unweighted_expected,
    )

    weighted_expected = F.tensor([1.0 / 3, 2.0 / 3, 0.0])
    assert F.allclose(
        normalized.edges[("B", "r3", "B")].data["w"],
        weighted_expected,
    )
| 2853 | |
| 2854 | |
| 2855 | @unittest.skipIf( |