(idtype)
| 432 | |
@parametrize_idtype
def test_pull_0deg(idtype):
    """Exercise ``pull`` on a graph that contains a zero in-degree node.

    The graph has two nodes and a single edge ``0 -> 1``, so node 0 never
    receives a message. Two scenarios are checked:

    1. Pulling ``[0, 1]`` with a reduce/apply pair that writes ``"x"``:
       node 0's ``"x"`` comes from the registered initializer (constant 2)
       and is then doubled by the apply function; node 1's ``"x"`` is
       ``(h[1] + h[0]) * 2``.
    2. Pulling only node 0 with an apply function that doubles ``"h"``:
       node 0's ``"h"`` is doubled and node 1's ``"h"`` stays unchanged.
    """
    graph = dgl.graph(([0], [1]), idtype=idtype, device=F.ctx())

    def copy_src_h(edges):
        # Each edge forwards its source node's "h" as message "m".
        return {"m": edges.src["h"]}

    def self_plus_messages(nodes):
        # Sum the mailbox over axis 1 and add the node's own "h".
        inbox_total = F.sum(nodes.mailbox["m"], 1)
        return {"x": nodes.data["h"] + inbox_total}

    def double_x(nodes):
        return {"x": nodes.data["x"] * 2}

    def fill_with_two(shape, dtype, ctx, ids):
        # Node initializer for "x"; ``ids`` is accepted but unused here.
        return 2 + F.zeros(shape, dtype, ctx)

    graph.set_n_initializer(fill_with_two, "x")

    # Scenario 1: pull a zero in-degree node together with a non-zero one.
    feats = F.randn((2, 5))
    graph.ndata["h"] = feats
    graph.pull([0, 1], copy_src_h, self_plus_messages, double_x)
    pulled = graph.ndata["x"]
    # Node 0 got no messages: initializer value 2, then doubled by double_x.
    assert F.allclose(pulled[0], F.full_1d(5, 4, dtype=F.float32))
    # Node 1: (own h + node 0's h) doubled, i.e. twice the column sum.
    assert F.allclose(pulled[1], F.sum(feats, 0) * 2)

    # Scenario 2: pull only the zero in-degree node.
    feats = F.randn((2, 5))
    graph.ndata["h"] = feats

    def double_h(nodes):
        return {"h": nodes.data["h"] * 2}

    # Suppress the warning that the edge UDF's input graph contains no
    # valid edges (it is filtered as a UserWarning).
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        graph.pull(0, copy_src_h, self_plus_messages, double_h)

    pulled = graph.ndata["h"]
    # Node 0: no messages arrived, so only the apply function ran on "h".
    assert F.allclose(pulled[0], 2 * feats[0])
    # Node 1 was not pulled and must be left untouched.
    assert F.allclose(pulled[1], feats[1])
| 474 | |
| 475 | |
| 476 | def test_dynamic_addition(): |
nothing calls this directly
no test coverage detected