@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["dglgraph"]))
@pytest.mark.parametrize("reducer", ["sum", "max", "mean"])
def test_weighted_reduce_readout(g, idtype, reducer):
    """Weighted node/edge readouts on a batched graph match per-graph readouts.

    For each weighted readout variant (the generic ``dgl.readout_nodes`` /
    ``dgl.readout_edges`` with ``op=reducer`` and the ``dgl.<reducer>_nodes`` /
    ``dgl.<reducer>_edges`` shortcuts), the result on the batched graph must
    equal the row-wise concatenation of the results on each unbatched
    component graph.

    Parameters
    ----------
    g : DGLGraph
        Homogeneous test graph from ``get_cases``.
    idtype :
        Index dtype; presumably supplied by a ``parametrize_idtype`` decorator
        outside this view -- TODO confirm.
    reducer : str
        One of ``"sum"``, ``"max"``, ``"mean"``.
    """
    g = g.astype(idtype).to(F.ctx())
    g.ndata["h"] = F.randn((g.num_nodes(), 3))
    g.ndata["w"] = F.randn((g.num_nodes(), 1))
    g.edata["h"] = F.randn((g.num_edges(), 2))
    g.edata["w"] = F.randn((g.num_edges(), 1))

    # g is not modified below, so a single unbatch serves every check.
    subgraphs = dgl.unbatch(g)

    node_shortcut = getattr(dgl, "{}_nodes".format(reducer))
    edge_shortcut = getattr(dgl, "{}_edges".format(reducer))

    # (label, readout) pairs; each readout maps a graph to its weighted
    # readout of feature "h" with weight "w". Order matches the original
    # test: node readouts first (Test.1), then edge readouts (Test.2).
    readouts = [
        (
            "readout_nodes(op={})".format(reducer),
            lambda graph: dgl.readout_nodes(graph, "h", "w", op=reducer),
        ),
        (
            "{}_nodes".format(reducer),
            lambda graph: node_shortcut(graph, "h", "w"),
        ),
        (
            "readout_edges(op={})".format(reducer),
            lambda graph: dgl.readout_edges(graph, "h", "w", op=reducer),
        ),
        (
            "{}_edges".format(reducer),
            lambda graph: edge_shortcut(graph, "h", "w"),
        ),
    ]

    for label, readout in readouts:
        batched = readout(g)
        per_graph = [readout(sg) for sg in subgraphs]
        # The label identifies which variant diverged on failure.
        assert F.allclose(batched, F.cat(per_graph, dim=0)), label
| 122 | |
| 123 | |
| 124 | @parametrize_idtype |