(idtype)
| 165 | |
@parametrize_idtype
def test_batch_propagate(idtype):
    """Check ``prop_edges`` on a batched graph of two trees, then unbatch.

    ``tree1`` and ``tree2`` are batched into one graph. Messages are then
    propagated along two hard-coded edge frontiers, where each frontier
    mixes edges from both trees. After unbatching, one node of each tree is
    checked for its aggregated ``"h"`` value.
    """
    t1 = tree1(idtype)
    t2 = tree2(idtype)

    bg = dgl.batch([t1, t2])
    # dgl.batch places t2's nodes after t1's, so t2-local node ``i`` is
    # ``i + offset`` in ``bg``. Derive the offset instead of hard-coding it.
    offset = t1.num_nodes()

    def _mfunc(edges):
        # Message: copy the source node's "h" feature along each edge.
        return {"m": edges.src["h"]}

    def _rfunc(nodes):
        # Reduce: sum received messages over axis 1 of the mailbox
        # (the per-node incoming-message axis in DGL's mailbox layout).
        return {"h": F.sum(nodes.mailbox["m"], 1)}

    # Ordered list of (src, dst) edge frontiers; prop_edges applies them
    # one after another.
    order = []

    # Step 1: first frontier (t1 ids as-is, t2 ids shifted by offset).
    u = [3, 4, 2 + offset, 0 + offset]
    v = [1, 1, 4 + offset, 4 + offset]
    order.append((u, v))

    # Step 2: second frontier, ending at t1 node 0 and t2 node 1.
    u = [1, 2, 4 + offset, 3 + offset]
    v = [0, 0, 1 + offset, 1 + offset]
    order.append((u, v))

    bg.prop_edges(order, _mfunc, _rfunc)
    t1, t2 = dgl.unbatch(bg)

    assert F.asnumpy(t1.ndata["h"][0]) == 9
    assert F.asnumpy(t2.ndata["h"][1]) == 5
| 193 | |
| 194 | |
| 195 | @parametrize_idtype |
nothing calls this directly
no test coverage detected