(idtype)
| 236 | |
| 237 | @parametrize_idtype |
| 238 | def test_pull_multi_fallback(idtype): |
| 239 | # create a graph with zero in degree nodes |
| 240 | g = dgl.graph([]) |
| 241 | g = g.astype(idtype).to(F.ctx()) |
| 242 | g.add_nodes(10) |
| 243 | for i in range(1, 9): |
| 244 | g.add_edges(0, i) |
| 245 | g.add_edges(i, 9) |
| 246 | g.ndata["h"] = F.randn((10, D)) |
| 247 | g.edata["w1"] = F.randn((16,)) |
| 248 | g.edata["w2"] = F.randn((16, D)) |
| 249 | |
| 250 | def _mfunc_hxw1(edges): |
| 251 | return {"m1": edges.src["h"] * F.unsqueeze(edges.data["w1"], 1)} |
| 252 | |
| 253 | def _mfunc_hxw2(edges): |
| 254 | return {"m2": edges.src["h"] * edges.data["w2"]} |
| 255 | |
| 256 | def _rfunc_m1(nodes): |
| 257 | return {"o1": F.sum(nodes.mailbox["m1"], 1)} |
| 258 | |
| 259 | def _rfunc_m2(nodes): |
| 260 | return {"o2": F.sum(nodes.mailbox["m2"], 1)} |
| 261 | |
| 262 | def _rfunc_m1max(nodes): |
| 263 | return {"o3": F.max(nodes.mailbox["m1"], 1)} |
| 264 | |
| 265 | def _afunc(nodes): |
| 266 | ret = {} |
| 267 | for k, v in nodes.data.items(): |
| 268 | if k.startswith("o"): |
| 269 | ret[k] = 2 * v |
| 270 | return ret |
| 271 | |
| 272 | # nodes to pull |
| 273 | def _pull_nodes(nodes): |
| 274 | # compute ground truth |
| 275 | g.pull(nodes, _mfunc_hxw1, _rfunc_m1, _afunc) |
| 276 | o1 = g.ndata.pop("o1") |
| 277 | g.pull(nodes, _mfunc_hxw2, _rfunc_m2, _afunc) |
| 278 | o2 = g.ndata.pop("o2") |
| 279 | g.pull(nodes, _mfunc_hxw1, _rfunc_m1max, _afunc) |
| 280 | o3 = g.ndata.pop("o3") |
| 281 | # v2v spmv |
| 282 | g.pull( |
| 283 | nodes, |
| 284 | fn.u_mul_e("h", "w1", "m1"), |
| 285 | fn.sum(msg="m1", out="o1"), |
| 286 | _afunc, |
| 287 | ) |
| 288 | assert F.allclose(o1, g.ndata.pop("o1")) |
| 289 | # v2v fallback to e2v |
| 290 | g.pull( |
| 291 | nodes, |
| 292 | fn.u_mul_e("h", "w2", "m2"), |
| 293 | fn.sum(msg="m2", out="o2"), |
| 294 | _afunc, |
| 295 | ) |
no test coverage detected