@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["dglgraph"]))
@pytest.mark.parametrize("reducer", ["sum", "max", "mean"])
def test_reduce_readout(g, idtype, reducer):
    """Check that readouts on a batched graph match per-graph readouts.

    Random node ("h", width 3) and edge ("h", width 2) features are attached
    to ``g``. For both nodes and edges, two APIs are exercised:

    * the generic ``dgl.readout_nodes`` / ``dgl.readout_edges`` with
      ``op=reducer``;
    * the reducer-specific shortcut looked up by name
      (``dgl.<reducer>_nodes`` / ``dgl.<reducer>_edges``).

    Each result on the batched graph must be close to the dim-0
    concatenation of the same readout on every unbatched component.
    """
    g = g.astype(idtype).to(F.ctx())
    g.ndata["h"] = F.randn((g.num_nodes(), 3))
    g.edata["h"] = F.randn((g.num_edges(), 2))

    # Unbatch once; every check below compares against the same components.
    subgraphs = dgl.unbatch(g)

    def check_matches_unbatched(readout):
        """Assert readout(g) equals the concatenation of readout(sg) per component."""
        actual = readout(g)
        expected = F.cat([readout(sg) for sg in subgraphs], dim=0)
        assert F.allclose(actual, expected)

    # Test.1: node readout (generic API, then reducer-specific shortcut).
    check_matches_unbatched(
        lambda graph: dgl.readout_nodes(graph, "h", op=reducer)
    )
    nodes_shortcut = getattr(dgl, "{}_nodes".format(reducer))
    check_matches_unbatched(lambda graph: nodes_shortcut(graph, "h"))

    # Test.2: edge readout (generic API, then reducer-specific shortcut).
    check_matches_unbatched(
        lambda graph: dgl.readout_edges(graph, "h", op=reducer)
    )
    edges_shortcut = getattr(dgl, "{}_edges".format(reducer))
    check_matches_unbatched(lambda graph: edges_shortcut(graph, "h"))
| 73 | |
| 74 | |
| 75 | @parametrize_idtype |