MCP · copy
hub / github.com/dmlc/dgl / test_in_subgraph

Function test_in_subgraph

tests/python/common/test_subgraph.py:435–519  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

433
434@parametrize_idtype
435def test_in_subgraph(idtype):
436 hg = dgl.heterograph(
437 {
438 ("user", "follow", "user"): (
439 [1, 2, 3, 0, 2, 3, 0],
440 [0, 0, 0, 1, 1, 1, 2],
441 ),
442 ("user", "play", "game"): ([0, 0, 1, 3], [0, 1, 2, 2]),
443 ("game", "liked-by", "user"): (
444 [2, 2, 2, 1, 1, 0],
445 [0, 1, 2, 0, 3, 0],
446 ),
447 ("user", "flips", "coin"): ([0, 1, 2, 3], [0, 0, 0, 0]),
448 },
449 idtype=idtype,
450 num_nodes_dict={"user": 5, "game": 10, "coin": 8},
451 ).to(F.ctx())
452 subg = dgl.in_subgraph(hg, {"user": [0, 1], "game": 0})
453 assert subg.idtype == idtype
454 assert len(subg.ntypes) == 3
455 assert len(subg.etypes) == 4
456 u, v = subg["follow"].edges()
457 edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
458 assert F.array_equal(
459 hg["follow"].edge_ids(u, v), subg["follow"].edata[dgl.EID]
460 )
461 assert edge_set == {(1, 0), (2, 0), (3, 0), (0, 1), (2, 1), (3, 1)}
462 u, v = subg["play"].edges()
463 edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
464 assert F.array_equal(hg["play"].edge_ids(u, v), subg["play"].edata[dgl.EID])
465 assert edge_set == {(0, 0)}
466 u, v = subg["liked-by"].edges()
467 edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
468 assert F.array_equal(
469 hg["liked-by"].edge_ids(u, v), subg["liked-by"].edata[dgl.EID]
470 )
471 assert edge_set == {(2, 0), (2, 1), (1, 0), (0, 0)}
472 assert subg["flips"].num_edges() == 0
473 for ntype in subg.ntypes:
474 assert dgl.NID not in subg.nodes[ntype].data
475
476 # Test store_ids
477 subg = dgl.in_subgraph(hg, {"user": [0, 1], "game": 0}, store_ids=False)
478 for etype in ["follow", "play", "liked-by"]:
479 assert dgl.EID not in subg.edges[etype].data
480 for ntype in subg.ntypes:
481 assert dgl.NID not in subg.nodes[ntype].data
482
483 # Test relabel nodes
484 subg = dgl.in_subgraph(hg, {"user": [0, 1], "game": 0}, relabel_nodes=True)
485 assert subg.idtype == idtype
486 assert len(subg.ntypes) == 3
487 assert len(subg.etypes) == 4
488
489 u, v = subg["follow"].edges()
490 old_u = F.gather_row(subg.nodes["user"].data[dgl.NID], u)
491 old_v = F.gather_row(subg.nodes["user"].data[dgl.NID], v)
492 assert F.array_equal(

Callers

nothing calls this directly

Calls 8

in_subgraph · Method · 0.80
asnumpy · Method · 0.80
to · Method · 0.45
ctx · Method · 0.45
edges · Method · 0.45
edge_ids · Method · 0.45
num_edges · Method · 0.45
num_nodes · Method · 0.45

Tested by

no test coverage detected