Test batching a DGLGraph and a batched DGLGraph.
Signature: `test_batching_batched(idtype)`
| 123 | |
| 124 | @parametrize_idtype |
| 125 | def test_batching_batched(idtype): |
| 126 | """Test batching a DGLGraph and a batched DGLGraph.""" |
| 127 | g1 = dgl.heterograph( |
| 128 | { |
| 129 | ("user", "follows", "user"): ([0, 1], [1, 2]), |
| 130 | ("user", "plays", "game"): ([0, 1], [0, 0]), |
| 131 | }, |
| 132 | idtype=idtype, |
| 133 | device=F.ctx(), |
| 134 | ) |
| 135 | g2 = dgl.heterograph( |
| 136 | { |
| 137 | ("user", "follows", "user"): ([0, 1], [1, 2]), |
| 138 | ("user", "plays", "game"): ([0, 1], [0, 0]), |
| 139 | }, |
| 140 | idtype=idtype, |
| 141 | device=F.ctx(), |
| 142 | ) |
| 143 | bg1 = dgl.batch([g1, g2]) |
| 144 | g3 = dgl.heterograph( |
| 145 | { |
| 146 | ("user", "follows", "user"): ([0], [1]), |
| 147 | ("user", "plays", "game"): ([1], [0]), |
| 148 | }, |
| 149 | idtype=idtype, |
| 150 | device=F.ctx(), |
| 151 | ) |
| 152 | bg2 = dgl.batch([bg1, g3]) |
| 153 | assert bg2.idtype == idtype |
| 154 | assert bg2.device == F.ctx() |
| 155 | assert bg2.ntypes == g3.ntypes |
| 156 | assert bg2.etypes == g3.etypes |
| 157 | assert bg2.canonical_etypes == g3.canonical_etypes |
| 158 | assert bg2.batch_size == 3 |
| 159 | |
| 160 | # Test number of nodes |
| 161 | for ntype in bg2.ntypes: |
| 162 | assert F.asnumpy(bg2.batch_num_nodes(ntype)).tolist() == [ |
| 163 | g1.num_nodes(ntype), |
| 164 | g2.num_nodes(ntype), |
| 165 | g3.num_nodes(ntype), |
| 166 | ] |
| 167 | assert bg2.num_nodes(ntype) == ( |
| 168 | g1.num_nodes(ntype) + g2.num_nodes(ntype) + g3.num_nodes(ntype) |
| 169 | ) |
| 170 | |
| 171 | # Test number of edges |
| 172 | for etype in bg2.canonical_etypes: |
| 173 | assert F.asnumpy(bg2.batch_num_edges(etype)).tolist() == [ |
| 174 | g1.num_edges(etype), |
| 175 | g2.num_edges(etype), |
| 176 | g3.num_edges(etype), |
| 177 | ] |
| 178 | assert bg2.num_edges(etype) == ( |
| 179 | g1.num_edges(etype) + g2.num_edges(etype) + g3.num_edges(etype) |
| 180 | ) |
| 181 | |
| 182 | # Test relabeled nodes |
Nothing calls this function directly.
No test coverage was detected.