| 455 | |
@parametrize_idtype
def test_slice_batch(idtype):
    """Check that ``dgl.slice_batch`` recovers each component of a batch.

    Three heterographs share the same node/edge types. Some relation types
    have no edges in some members, and ``g2`` declares extra isolated nodes
    via ``num_nodes_dict``. The graphs are given node/edge features, batched
    with ``dgl.batch``, and then sliced back out after the batched graph is
    restricted to each of the "coo", "csr" and "csc" formats in turn.

    Every slice must match its source graph in:
    - node types and canonical edge types;
    - idtype and device;
    - per-type node and edge counts;
    - feature names and feature values.
    """
    g1 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([], []),
            ("user", "follows", "game"): ([0, 0], [1, 4]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    g2 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1], [0, 0]),
            ("user", "follows", "game"): ([0, 1], [1, 4]),
        },
        num_nodes_dict={"user": 4, "game": 6},
        idtype=idtype,
        device=F.ctx(),
    )
    g3 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0], [2]),
            ("user", "plays", "game"): ([1, 2], [3, 4]),
            ("user", "follows", "game"): ([], []),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    g_list = [g1, g2, g3]

    # Attach features to the component graphs *before* batching, so each g_i
    # holds the ground-truth values its slice must reproduce. If the features
    # were set only on the batched graph, g_i.nodes[...].data would be empty
    # and the per-feature comparisons below would iterate over nothing.
    for g in g_list:
        g.nodes["user"].data["h1"] = F.randn((g.num_nodes("user"), 2))
        g.nodes["user"].data["h2"] = F.randn((g.num_nodes("user"), 5))
        g.edges[("user", "follows", "user")].data["h1"] = F.randn(
            (g.num_edges(("user", "follows", "user")), 2)
        )
    bg = dgl.batch(g_list)

    for fmat in ["coo", "csr", "csc"]:
        # Restrict the batched graph to one sparse format at a time before
        # slicing.
        bg = bg.formats(fmat)
        for i, g_i in enumerate(g_list):
            g_slice = dgl.slice_batch(bg, i)
            assert g_i.ntypes == g_slice.ntypes
            assert g_i.canonical_etypes == g_slice.canonical_etypes
            assert g_i.idtype == g_slice.idtype
            assert g_i.device == g_slice.device

            for nty in g_i.ntypes:
                assert g_i.num_nodes(nty) == g_slice.num_nodes(nty)
                # The slice must carry exactly the same feature names,
                # with no features missing and none extra.
                assert set(g_i.nodes[nty].data) == set(g_slice.nodes[nty].data)
                for feat in g_i.nodes[nty].data:
                    assert F.allclose(
                        g_i.nodes[nty].data[feat], g_slice.nodes[nty].data[feat]
                    )

            for ety in g_i.canonical_etypes:
                assert g_i.num_edges(ety) == g_slice.num_edges(ety)
                assert set(g_i.edges[ety].data) == set(g_slice.edges[ety].data)
                for feat in g_i.edges[ety].data:
                    assert F.allclose(
                        g_i.edges[ety].data[feat], g_slice.edges[ety].data[feat]
                    )