| 2554 | |
| 2555 | @parametrize_idtype |
| 2556 | def test_module_add_reverse(idtype): |
| 2557 | transform = dgl.AddReverse() |
| 2558 | |
| 2559 | # Case1: Add reverse edges for a homogeneous graph |
| 2560 | g = dgl.graph(([0], [1]), idtype=idtype, device=F.ctx()) |
| 2561 | g.ndata["h"] = F.randn((g.num_nodes(), 3)) |
| 2562 | g.edata["w"] = F.randn((g.num_edges(), 2)) |
| 2563 | new_g = transform(g) |
| 2564 | assert new_g.device == g.device |
| 2565 | assert new_g.idtype == g.idtype |
| 2566 | assert g.num_nodes() == new_g.num_nodes() |
| 2567 | src, dst = new_g.edges() |
| 2568 | eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) |
| 2569 | assert eset == {(0, 1), (1, 0)} |
| 2570 | assert F.allclose(g.ndata["h"], new_g.ndata["h"]) |
| 2571 | assert F.allclose(g.edata["w"], F.narrow_row(new_g.edata["w"], 0, 1)) |
| 2572 | assert F.allclose( |
| 2573 | F.narrow_row(new_g.edata["w"], 1, 2), |
| 2574 | F.zeros((1, 2), F.float32, F.ctx()), |
| 2575 | ) |
| 2576 | |
| 2577 | # Case2: Add reverse edges for a homogeneous graph and copy edata |
| 2578 | transform = dgl.AddReverse(copy_edata=True) |
| 2579 | new_g = transform(g) |
| 2580 | assert new_g.device == g.device |
| 2581 | assert new_g.idtype == g.idtype |
| 2582 | assert g.num_nodes() == new_g.num_nodes() |
| 2583 | src, dst = new_g.edges() |
| 2584 | eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) |
| 2585 | assert eset == {(0, 1), (1, 0)} |
| 2586 | assert F.allclose(g.ndata["h"], new_g.ndata["h"]) |
| 2587 | assert F.allclose(g.edata["w"], F.narrow_row(new_g.edata["w"], 0, 1)) |
| 2588 | assert F.allclose(g.edata["w"], F.narrow_row(new_g.edata["w"], 1, 2)) |
| 2589 | |
| 2590 | # Case3: Add reverse edges for a heterogeneous graph |
| 2591 | g = dgl.heterograph( |
| 2592 | { |
| 2593 | ("user", "plays", "game"): ([0, 1], [1, 1]), |
| 2594 | ("user", "follows", "user"): ([1, 2], [2, 2]), |
| 2595 | }, |
| 2596 | device=F.ctx(), |
| 2597 | ) |
| 2598 | new_g = transform(g) |
| 2599 | assert new_g.device == g.device |
| 2600 | assert new_g.idtype == g.idtype |
| 2601 | assert g.ntypes == new_g.ntypes |
| 2602 | assert set(new_g.canonical_etypes) == { |
| 2603 | ("user", "plays", "game"), |
| 2604 | ("user", "follows", "user"), |
| 2605 | ("game", "rev_plays", "user"), |
| 2606 | } |
| 2607 | for nty in g.ntypes: |
| 2608 | assert g.num_nodes(nty) == new_g.num_nodes(nty) |
| 2609 | |
| 2610 | src, dst = new_g.edges(etype="plays") |
| 2611 | eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst)))) |
| 2612 | assert eset == {(0, 1), (1, 1)} |
| 2613 | |