(idtype)
| 49 | |
@parametrize_idtype
def test_batch_unbatch(idtype):
    """Round-trip two trees through ``dgl.batch`` and ``dgl.unbatch``.

    First checks the batched graph: total node/edge counts, batch size,
    per-graph bookkeeping, and that it keeps the parametrized ``idtype``.
    Then checks that unbatching gives back graphs whose ``idtype``, node/edge
    counts and ``"h"`` node/edge features match the inputs.
    """
    t1 = tree1(idtype)
    t2 = tree2(idtype)

    bg = dgl.batch([t1, t2])
    # The test is parametrized over idtype, so batching must preserve it.
    assert bg.idtype == idtype
    assert bg.num_nodes() == 10
    assert bg.num_edges() == 8
    assert bg.batch_size == 2
    # Each input tree contributes 5 nodes and 4 edges to the batch.
    assert F.allclose(bg.batch_num_nodes(), F.tensor([5, 5]))
    assert F.allclose(bg.batch_num_edges(), F.tensor([4, 4]))

    tt1, tt2 = dgl.unbatch(bg)
    for original, restored in ((t1, tt1), (t2, tt2)):
        # Structure and ID type must round-trip, not just the feature tensors.
        assert restored.idtype == idtype
        assert restored.num_nodes() == original.num_nodes()
        assert restored.num_edges() == original.num_edges()
        assert F.allclose(original.ndata["h"], restored.ndata["h"])
        assert F.allclose(original.edata["h"], restored.edata["h"])
| 67 | |
| 68 | |
| 69 | @parametrize_idtype |
nothing calls this directly
no test coverage detected