MCPcopy
hub / github.com/dmlc/dgl / test_slice_batch

Function test_slice_batch

tests/python/common/test_batch-heterograph.py:457–514  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

@parametrize_idtype
def test_slice_batch(idtype):
    """Check that ``dgl.slice_batch`` recovers each component of a batch.

    Three heterographs with identical node/edge types (including some empty
    relations and one graph with explicit ``num_nodes_dict`` padding) are
    batched. Node and edge features are then attached to the batched graph.
    Every component is sliced back out under each sparse format ("coo",
    "csr", "csc") and compared against the original component.

    For each slice the test checks node/edge types, idtype, device, per-type
    node/edge counts, and that every feature of the batched graph is carried
    over as the row range belonging to that component.

    Args:
        idtype: Integer ID type supplied by ``parametrize_idtype``.
    """
    g1 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([], []),
            ("user", "follows", "game"): ([0, 0], [1, 4]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    g2 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1], [0, 0]),
            ("user", "follows", "game"): ([0, 1], [1, 4]),
        },
        num_nodes_dict={"user": 4, "game": 6},
        idtype=idtype,
        device=F.ctx(),
    )
    g3 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0], [2]),
            ("user", "plays", "game"): ([1, 2], [3, 4]),
            ("user", "follows", "game"): ([], []),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    g_list = [g1, g2, g3]
    bg = dgl.batch(g_list)
    bg.nodes["user"].data["h1"] = F.randn((bg.num_nodes("user"), 2))
    bg.nodes["user"].data["h2"] = F.randn((bg.num_nodes("user"), 5))
    bg.edges[("user", "follows", "user")].data["h1"] = F.randn(
        (bg.num_edges(("user", "follows", "user")), 2)
    )
    for fmat in ["coo", "csr", "csc"]:
        # Derive each format from the original batch instead of chaining
        # formats() calls, so every iteration exercises exactly one format.
        bg_fmt = bg.formats(fmat)
        for i, g_i in enumerate(g_list):
            g_slice = dgl.slice_batch(bg_fmt, i)
            assert g_i.ntypes == g_slice.ntypes
            assert g_i.canonical_etypes == g_slice.canonical_etypes
            assert g_i.idtype == g_slice.idtype
            assert g_i.device == g_slice.device

            # The features were attached to ``bg`` after batching, so the
            # component graphs ``g_i`` carry none. The expected values are
            # the rows of ``bg``'s features owned by component ``i``, which
            # start after all nodes/edges of the preceding components.
            preceding = g_list[:i]

            for nty in g_i.ntypes:
                count = g_i.num_nodes(nty)
                assert count == g_slice.num_nodes(nty)
                start = sum(g.num_nodes(nty) for g in preceding)
                full_data = bg.nodes[nty].data
                sliced_data = g_slice.nodes[nty].data
                assert set(full_data.keys()) == set(sliced_data.keys())
                for feat, value in full_data.items():
                    assert F.allclose(
                        value[start : start + count], sliced_data[feat]
                    )

            for ety in g_i.canonical_etypes:
                count = g_i.num_edges(ety)
                assert count == g_slice.num_edges(ety)
                start = sum(g.num_edges(ety) for g in preceding)
                full_data = bg.edges[ety].data
                sliced_data = g_slice.edges[ety].data
                assert set(full_data.keys()) == set(sliced_data.keys())
                for feat, value in full_data.items():
                    assert F.allclose(
                        value[start : start + count], sliced_data[feat]
                    )

Callers

nothing calls this directly

Calls 4

ctx (method) · 0.45
num_nodes (method) · 0.45
num_edges (method) · 0.45
formats (method) · 0.45

Tested by

no test coverage detected