(idtype)
| 68 | |
@parametrize_idtype
def test_batch_unbatch1(idtype):
    """Batch an already-batched graph again, then unbatch and compare.

    ``inner`` batches ``tree1`` and ``tree2``. ``outer`` then batches
    ``tree2`` together with ``inner``. The assertions expect ``outer`` to
    expose three component graphs, which unbatch back in the order
    ``[tree2, tree1, tree2]`` with matching node/edge ``"h"`` features.
    """
    first = tree1(idtype)
    second = tree2(idtype)
    inner = dgl.batch([first, second])
    outer = dgl.batch([second, inner])

    # Aggregate sizes of the nested batch, then the per-component counts.
    assert outer.num_nodes() == 15
    assert outer.num_edges() == 12
    assert outer.batch_size == 3
    assert F.allclose(outer.batch_num_nodes(), F.tensor([5, 5, 5]))
    assert F.allclose(outer.batch_num_edges(), F.tensor([4, 4, 4]))

    # Explicit 3-way unpacking: fails loudly if unbatch yields a different count.
    part_a, part_b, part_c = dgl.unbatch(outer)
    expected_sources = (second, first, second)
    for source, part in zip(expected_sources, (part_a, part_b, part_c)):
        assert F.allclose(source.ndata["h"], part.ndata["h"])
        assert F.allclose(source.edata["h"], part.edata["h"])
| 88 | |
| 89 | |
| 90 | @unittest.skipIf( |
nothing calls this directly
no test coverage detected