@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["dglgraph"]))
def test_softmax(g, idtype):
    """Compare batched softmax readouts with a per-graph reference.

    Random features named ``"h"`` are attached to both nodes and edges.
    The result of ``dgl.softmax_nodes`` / ``dgl.softmax_edges`` on the
    (possibly batched) graph is then checked against the concatenation of
    ``F.softmax(..., dim=0)`` applied to each graph returned by
    ``dgl.unbatch``.
    """
    g = g.astype(idtype).to(F.ctx())
    g.ndata["h"] = F.randn((g.num_nodes(), 3))
    g.edata["h"] = F.randn((g.num_edges(), 2))

    # Test.1: node readout
    node_out = dgl.softmax_nodes(g, "h")
    node_ref = F.cat(
        [F.softmax(part.ndata["h"], dim=0) for part in dgl.unbatch(g)],
        dim=0,
    )
    assert F.allclose(node_out, node_ref)

    # Test.2: edge readout
    edge_out = dgl.softmax_edges(g, "h")
    edge_ref = F.cat(
        [F.softmax(part.edata["h"], dim=0) for part in dgl.unbatch(g)],
        dim=0,
    )
    assert F.allclose(edge_out, edge_ref)
| 214 | |
| 215 | @parametrize_idtype |