Tests the node/edge frames of batched and unbatched DGLGraphs; also serves as a regression test for the bug reported in https://github.com/dmlc/dgl/issues/1475.
Signature: `test_batch_unbatch_frame(idtype)`
| 93 | ) |
@parametrize_idtype
def test_batch_unbatch_frame(idtype):
    """Check that node/edge frames survive a batch/unbatch round trip.

    In-place writes to the frames of a batched graph must not leak back
    into the graphs it was built from, and ``dgl.unbatch`` must hand each
    component back exactly the feature rows it owns inside the batch.
    Regression test for https://github.com/dmlc/dgl/issues/1475.
    """
    feat_dim = 10
    tree_a = tree1(idtype)
    tree_b = tree2(idtype)
    n_nodes_a, n_edges_a = tree_a.num_nodes(), tree_a.num_edges()
    n_nodes_b, n_edges_b = tree_b.num_nodes(), tree_b.num_edges()
    sources = (
        (tree_a, n_nodes_a, n_edges_a),
        (tree_b, n_nodes_b, n_edges_b),
    )

    # Give every node and edge a random (almost surely non-zero) feature row.
    for graph, n_nodes, n_edges in sources:
        graph.ndata["h"] = F.randn((n_nodes, feat_dim))
        graph.edata["h"] = F.randn((n_edges, feat_dim))

    batch_ab = dgl.batch([tree_a, tree_b])
    batch_b = dgl.batch([tree_b])

    # Zero, in place, the rows owned by the first component of each batch.
    for batched, n_nodes, n_edges in (
        (batch_ab, n_nodes_a, n_edges_a),
        (batch_b, n_nodes_b, n_edges_b),
    ):
        batched.ndata["h"][:n_nodes] = F.zeros((n_nodes, feat_dim))
        batched.edata["h"][:n_edges] = F.zeros((n_edges, feat_dim))

    # The source graphs must not observe those in-place writes.
    for graph, n_nodes, n_edges in sources:
        assert not F.allclose(graph.ndata["h"], F.zeros((n_nodes, feat_dim)))
        assert not F.allclose(graph.edata["h"], F.zeros((n_edges, feat_dim)))

    # Tuple unpacking also pins how many graphs each unbatch returns.
    unbatched_a, unbatched_b = dgl.unbatch(batch_ab)
    (unbatched_only_b,) = dgl.unbatch(batch_b)

    # Zeroed slices come back zeroed; the untouched slice matches its source.
    assert F.allclose(unbatched_a.ndata["h"], F.zeros((n_nodes_a, feat_dim)))
    assert F.allclose(unbatched_a.edata["h"], F.zeros((n_edges_a, feat_dim)))
    assert F.allclose(unbatched_b.ndata["h"], tree_b.ndata["h"])
    assert F.allclose(unbatched_b.edata["h"], tree_b.edata["h"])
    assert F.allclose(
        unbatched_only_b.ndata["h"], F.zeros((n_nodes_b, feat_dim))
    )
    assert F.allclose(
        unbatched_only_b.edata["h"], F.zeros((n_edges_b, feat_dim))
    )
| 130 | |
| 131 | |
| 132 | @parametrize_idtype |