| 12 | |
@parametrize_idtype
def test_sum_case1(idtype):
    """Check ``dgl.sum_nodes`` on a single graph and on a batched graph.

    NOTE: If you want to update this test case, remember to update the
    docstring example too!!!
    """
    ctx = F.ctx()

    # A 2-node graph and a 3-node graph, each carrying a scalar node feature "h".
    first = dgl.graph(([0, 1], [1, 0]), idtype=idtype, device=ctx)
    first.ndata["h"] = F.tensor([1.0, 2.0])
    second = dgl.graph(([0, 1], [1, 2]), idtype=idtype, device=ctx)
    second.ndata["h"] = F.tensor([1.0, 2.0, 3.0])

    # Batch them; the 5 weights cover the 2 nodes of `first` followed by the
    # 3 nodes of `second`.
    batched = dgl.batch([first, second])
    batched.ndata["w"] = F.tensor([0.1, 0.2, 0.1, 0.5, 0.2])

    # (expected per-graph readout, positional args passed to dgl.sum_nodes)
    expectations = (
        ([3.0], (first, "h")),  # unbatched: one row
        ([3.0, 6.0], (batched, "h")),  # batched: one row per member graph
        ([0.5, 1.7], (batched, "h", "w")),  # weighted by "w"
    )
    for expected, call_args in expectations:
        assert F.allclose(F.tensor(expected), dgl.sum_nodes(*call_args))
| 26 | |
| 27 | |
| 28 | @parametrize_idtype |