MCPcopy
hub / github.com/dmlc/dgl / test_module_add_reverse

Function test_module_add_reverse

tests/python/common/transforms/test_transform.py:2556–2651  ·  view source on GitHub ↗
(idtype)

Source from the content-addressed store, hash-verified

2554
2555@parametrize_idtype
2556def test_module_add_reverse(idtype):
2557 transform = dgl.AddReverse()
2558
2559 # Case1: Add reverse edges for a homogeneous graph
2560 g = dgl.graph(([0], [1]), idtype=idtype, device=F.ctx())
2561 g.ndata["h"] = F.randn((g.num_nodes(), 3))
2562 g.edata["w"] = F.randn((g.num_edges(), 2))
2563 new_g = transform(g)
2564 assert new_g.device == g.device
2565 assert new_g.idtype == g.idtype
2566 assert g.num_nodes() == new_g.num_nodes()
2567 src, dst = new_g.edges()
2568 eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
2569 assert eset == {(0, 1), (1, 0)}
2570 assert F.allclose(g.ndata["h"], new_g.ndata["h"])
2571 assert F.allclose(g.edata["w"], F.narrow_row(new_g.edata["w"], 0, 1))
2572 assert F.allclose(
2573 F.narrow_row(new_g.edata["w"], 1, 2),
2574 F.zeros((1, 2), F.float32, F.ctx()),
2575 )
2576
2577 # Case2: Add reverse edges for a homogeneous graph and copy edata
2578 transform = dgl.AddReverse(copy_edata=True)
2579 new_g = transform(g)
2580 assert new_g.device == g.device
2581 assert new_g.idtype == g.idtype
2582 assert g.num_nodes() == new_g.num_nodes()
2583 src, dst = new_g.edges()
2584 eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
2585 assert eset == {(0, 1), (1, 0)}
2586 assert F.allclose(g.ndata["h"], new_g.ndata["h"])
2587 assert F.allclose(g.edata["w"], F.narrow_row(new_g.edata["w"], 0, 1))
2588 assert F.allclose(g.edata["w"], F.narrow_row(new_g.edata["w"], 1, 2))
2589
2590 # Case3: Add reverse edges for a heterogeneous graph
2591 g = dgl.heterograph(
2592 {
2593 ("user", "plays", "game"): ([0, 1], [1, 1]),
2594 ("user", "follows", "user"): ([1, 2], [2, 2]),
2595 },
2596 device=F.ctx(),
2597 )
2598 new_g = transform(g)
2599 assert new_g.device == g.device
2600 assert new_g.idtype == g.idtype
2601 assert g.ntypes == new_g.ntypes
2602 assert set(new_g.canonical_etypes) == {
2603 ("user", "plays", "game"),
2604 ("user", "follows", "user"),
2605 ("game", "rev_plays", "user"),
2606 }
2607 for nty in g.ntypes:
2608 assert g.num_nodes(nty) == new_g.num_nodes(nty)
2609
2610 src, dst = new_g.edges(etype="plays")
2611 eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
2612 assert eset == {(0, 1), (1, 1)}
2613

Callers

nothing calls this directly

Calls 7

transform (Function) · 0.85
asnumpy (Method) · 0.80
graph (Method) · 0.45
ctx (Method) · 0.45
num_nodes (Method) · 0.45
num_edges (Method) · 0.45
edges (Method) · 0.45

Tested by

no test coverage detected