(nodes)
| 271 | |
| 272 | # nodes to pull |
def _pull_nodes(nodes):
    """Check builtin ``g.pull`` results against UDF ground truth on ``nodes``.

    First runs ``g.pull`` with the user-defined message/reduce functions to
    obtain reference outputs ``o1``, ``o2`` and ``o3``. Then it repeats each
    pull with builtin ``dgl.function`` equivalents and asserts the results
    are close. Each output field is popped from ``g.ndata`` right after it is
    produced, so successive pulls do not see stale values.

    Relies on ``g``, ``F``, ``fn``, ``_afunc`` and the ``_mfunc_*`` /
    ``_rfunc_*`` helpers from the enclosing scope.

    Args:
        nodes: node IDs to pull into; passed straight through to ``g.pull``.
    """
    # Ground truth from the user-defined message/reduce functions.
    g.pull(nodes, _mfunc_hxw1, _rfunc_m1, _afunc)
    o1 = g.ndata.pop("o1")
    g.pull(nodes, _mfunc_hxw2, _rfunc_m2, _afunc)
    o2 = g.ndata.pop("o2")
    g.pull(nodes, _mfunc_hxw1, _rfunc_m1max, _afunc)
    o3 = g.ndata.pop("o3")

    # v2v spmv: builtin u_mul_e + sum on edge feature "w1".
    g.pull(
        nodes,
        fn.u_mul_e("h", "w1", "m1"),
        fn.sum(msg="m1", out="o1"),
        _afunc,
    )
    assert F.allclose(o1, g.ndata.pop("o1"))

    # v2v fallback to e2v: same builtins on edge feature "w2".
    g.pull(
        nodes,
        fn.u_mul_e("h", "w2", "m2"),
        fn.sum(msg="m2", out="o2"),
        _afunc,
    )
    assert F.allclose(o2, g.ndata.pop("o2"))

    # Builtin max reducer. ``o3`` used to be computed above and then thrown
    # away, so the max-reduce path was never actually checked.
    # NOTE(review): assumes _rfunc_m1max max-reduces "m1" into "o3" (its
    # output is popped as "o3" above) — confirm against its definition.
    g.pull(
        nodes,
        fn.u_mul_e("h", "w1", "m1"),
        fn.max(msg="m1", out="o3"),
        _afunc,
    )
    assert F.allclose(o3, g.ndata.pop("o3"))
| 297 | |
| 298 | # test#1: non-0deg nodes |
| 299 | nodes = [1, 2, 9] |
no test coverage detected