MCPcopy
hub / github.com/dmlc/dgl / test_out_subgraph

Function test_out_subgraph

tests/python/common/test_subgraph.py:523–618  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

521
522@parametrize_idtype
523def test_out_subgraph(idtype):
524 hg = dgl.heterograph(
525 {
526 ("user", "follow", "user"): (
527 [1, 2, 3, 0, 2, 3, 0],
528 [0, 0, 0, 1, 1, 1, 2],
529 ),
530 ("user", "play", "game"): ([0, 0, 1, 3], [0, 1, 2, 2]),
531 ("game", "liked-by", "user"): (
532 [2, 2, 2, 1, 1, 0],
533 [0, 1, 2, 0, 3, 0],
534 ),
535 ("user", "flips", "coin"): ([0, 1, 2, 3], [0, 0, 0, 0]),
536 },
537 idtype=idtype,
538 ).to(F.ctx())
539 subg = dgl.out_subgraph(hg, {"user": [0, 1], "game": 0})
540 assert subg.idtype == idtype
541 assert len(subg.ntypes) == 3
542 assert len(subg.etypes) == 4
543 u, v = subg["follow"].edges()
544 edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
545 assert edge_set == {(1, 0), (0, 1), (0, 2)}
546 assert F.array_equal(
547 hg["follow"].edge_ids(u, v), subg["follow"].edata[dgl.EID]
548 )
549 u, v = subg["play"].edges()
550 edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
551 assert edge_set == {(0, 0), (0, 1), (1, 2)}
552 assert F.array_equal(hg["play"].edge_ids(u, v), subg["play"].edata[dgl.EID])
553 u, v = subg["liked-by"].edges()
554 edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
555 assert edge_set == {(0, 0)}
556 assert F.array_equal(
557 hg["liked-by"].edge_ids(u, v), subg["liked-by"].edata[dgl.EID]
558 )
559 u, v = subg["flips"].edges()
560 edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
561 assert edge_set == {(0, 0), (1, 0)}
562 assert F.array_equal(
563 hg["flips"].edge_ids(u, v), subg["flips"].edata[dgl.EID]
564 )
565 for ntype in subg.ntypes:
566 assert dgl.NID not in subg.nodes[ntype].data
567
568 # Test store_ids
569 subg = dgl.out_subgraph(hg, {"user": [0, 1], "game": 0}, store_ids=False)
570 for etype in subg.canonical_etypes:
571 assert dgl.EID not in subg.edges[etype].data
572 for ntype in subg.ntypes:
573 assert dgl.NID not in subg.nodes[ntype].data
574
575 # Test relabel nodes
576 subg = dgl.out_subgraph(hg, {"user": [1], "game": 0}, relabel_nodes=True)
577 assert subg.idtype == idtype
578 assert len(subg.ntypes) == 3
579 assert len(subg.etypes) == 4
580

Callers

nothing calls this directly

Calls 6

asnumpy (method) · 0.80
to (method) · 0.45
ctx (method) · 0.45
edges (method) · 0.45
edge_ids (method) · 0.45
num_nodes (method) · 0.45

Tested by

no test coverage detected