(idtype)
| 519 | |
| 520 | @parametrize_idtype |
| 521 | def test_local_var(idtype): |
| 522 | g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]), idtype=idtype, device=F.ctx()) |
| 523 | g.ndata["h"] = F.zeros((g.num_nodes(), 3)) |
| 524 | g.edata["w"] = F.zeros((g.num_edges(), 4)) |
| 525 | |
| 526 | # test override |
| 527 | def foo(g): |
| 528 | g = g.local_var() |
| 529 | g.ndata["h"] = F.ones((g.num_nodes(), 3)) |
| 530 | g.edata["w"] = F.ones((g.num_edges(), 4)) |
| 531 | |
| 532 | foo(g) |
| 533 | assert F.allclose(g.ndata["h"], F.zeros((g.num_nodes(), 3))) |
| 534 | assert F.allclose(g.edata["w"], F.zeros((g.num_edges(), 4))) |
| 535 | |
| 536 | # test out-place update |
| 537 | def foo(g): |
| 538 | g = g.local_var() |
| 539 | g.nodes[[2, 3]].data["h"] = F.ones((2, 3)) |
| 540 | g.edges[[2, 3]].data["w"] = F.ones((2, 4)) |
| 541 | |
| 542 | foo(g) |
| 543 | assert F.allclose(g.ndata["h"], F.zeros((g.num_nodes(), 3))) |
| 544 | assert F.allclose(g.edata["w"], F.zeros((g.num_edges(), 4))) |
| 545 | |
| 546 | # test out-place update 2 |
| 547 | def foo(g): |
| 548 | g = g.local_var() |
| 549 | g.apply_nodes(lambda nodes: {"h": nodes.data["h"] + 10}, [2, 3]) |
| 550 | g.apply_edges(lambda edges: {"w": edges.data["w"] + 10}, [2, 3]) |
| 551 | |
| 552 | foo(g) |
| 553 | assert F.allclose(g.ndata["h"], F.zeros((g.num_nodes(), 3))) |
| 554 | assert F.allclose(g.edata["w"], F.zeros((g.num_edges(), 4))) |
| 555 | |
| 556 | # test auto-pop |
| 557 | def foo(g): |
| 558 | g = g.local_var() |
| 559 | g.ndata["hh"] = F.ones((g.num_nodes(), 3)) |
| 560 | g.edata["ww"] = F.ones((g.num_edges(), 4)) |
| 561 | |
| 562 | foo(g) |
| 563 | assert "hh" not in g.ndata |
| 564 | assert "ww" not in g.edata |
| 565 | |
| 566 | # test initializer1 |
| 567 | g = dgl.graph(([0, 1], [1, 1]), idtype=idtype, device=F.ctx()) |
| 568 | g.set_n_initializer(dgl.init.zero_initializer) |
| 569 | |
| 570 | def foo(g): |
| 571 | g = g.local_var() |
| 572 | g.nodes[0].data["h"] = F.ones((1, 1)) |
| 573 | assert F.allclose(g.ndata["h"], F.tensor([[1.0], [0.0]])) |
| 574 | |
| 575 | foo(g) |
| 576 | |
| 577 | # test initializer2 |
| 578 | def foo_e_initializer(shape, dtype, ctx, id_range): |
nothing calls this directly
no test coverage detected