Test the features of batched DGLGraphs
(idtype)
| 335 | ) |
| 336 | @parametrize_idtype |
| 337 | def test_empty_relation(idtype): |
| 338 | """Test the features of batched DGLGraphs""" |
| 339 | g1 = dgl.heterograph( |
| 340 | { |
| 341 | ("user", "follows", "user"): ([0, 1], [1, 2]), |
| 342 | ("user", "plays", "game"): ([], []), |
| 343 | }, |
| 344 | idtype=idtype, |
| 345 | device=F.ctx(), |
| 346 | ) |
| 347 | g1.nodes["user"].data["h1"] = F.tensor([[0.0], [1.0], [2.0]]) |
| 348 | g1.nodes["user"].data["h2"] = F.tensor([[3.0], [4.0], [5.0]]) |
| 349 | g1.edges["follows"].data["h1"] = F.tensor([[0.0], [1.0]]) |
| 350 | g1.edges["follows"].data["h2"] = F.tensor([[2.0], [3.0]]) |
| 351 | |
| 352 | g2 = dgl.heterograph( |
| 353 | { |
| 354 | ("user", "follows", "user"): ([0, 1], [1, 2]), |
| 355 | ("user", "plays", "game"): ([0, 1], [0, 0]), |
| 356 | }, |
| 357 | idtype=idtype, |
| 358 | device=F.ctx(), |
| 359 | ) |
| 360 | g2.nodes["user"].data["h1"] = F.tensor([[0.0], [1.0], [2.0]]) |
| 361 | g2.nodes["user"].data["h2"] = F.tensor([[3.0], [4.0], [5.0]]) |
| 362 | g2.nodes["game"].data["h1"] = F.tensor([[0.0]]) |
| 363 | g2.nodes["game"].data["h2"] = F.tensor([[1.0]]) |
| 364 | g2.edges["follows"].data["h1"] = F.tensor([[0.0], [1.0]]) |
| 365 | g2.edges["follows"].data["h2"] = F.tensor([[2.0], [3.0]]) |
| 366 | g2.edges["plays"].data["h1"] = F.tensor([[0.0], [1.0]]) |
| 367 | |
| 368 | bg = dgl.batch([g1, g2]) |
| 369 | |
| 370 | # Test number of nodes |
| 371 | for ntype in bg.ntypes: |
| 372 | assert F.asnumpy(bg.batch_num_nodes(ntype)).tolist() == [ |
| 373 | g1.num_nodes(ntype), |
| 374 | g2.num_nodes(ntype), |
| 375 | ] |
| 376 | |
| 377 | # Test number of edges |
| 378 | for etype in bg.canonical_etypes: |
| 379 | assert F.asnumpy(bg.batch_num_edges(etype)).tolist() == [ |
| 380 | g1.num_edges(etype), |
| 381 | g2.num_edges(etype), |
| 382 | ] |
| 383 | |
| 384 | # Test features |
| 385 | assert F.allclose( |
| 386 | bg.nodes["user"].data["h1"], |
| 387 | F.cat( |
| 388 | [g1.nodes["user"].data["h1"], g2.nodes["user"].data["h1"]], dim=0 |
| 389 | ), |
| 390 | ) |
| 391 | assert F.allclose( |
| 392 | bg.nodes["user"].data["h2"], |
| 393 | F.cat( |
| 394 | [g1.nodes["user"].data["h2"], g2.nodes["user"].data["h2"]], dim=0 |
nothing calls this directly
no test coverage detected