| 2656 | ) |
@parametrize_idtype
def test_module_to_simple(idtype):
    """Test ``dgl.ToSimple`` on a homogeneous graph and on a heterograph.

    The transform must merge parallel edges into a single edge per
    (src, dst) pair and store each pair's multiplicity in the ``"count"``
    edge feature. It must also keep the node set, node features, device,
    and ID type of the input graph.
    """

    def _count_by_edge(graph, etype):
        # Map every (src, dst) pair of ``etype`` to its stored multiplicity.
        # Comparing dicts keeps the check independent of output edge order.
        src, dst = graph.edges(etype=etype)
        counts = graph.edges[etype].data["count"]
        return {
            (int(s), int(d)): int(c)
            for s, d, c in zip(
                F.asnumpy(src), F.asnumpy(dst), F.asnumpy(counts)
            )
        }

    transform = dgl.ToSimple()

    # Homogeneous graph: edge 1->2 appears twice.
    g = dgl.graph(([0, 1, 1], [1, 2, 2]), idtype=idtype, device=F.ctx())
    g.ndata["h"] = F.randn((g.num_nodes(), 2))
    g.edata["w"] = F.tensor([[0.1], [0.2], [0.3]])
    sg = transform(g)
    assert sg.device == g.device
    assert sg.idtype == g.idtype
    assert sg.num_nodes() == g.num_nodes()
    assert sg.num_edges() == 2
    src, dst = sg.edges()
    eset = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
    assert eset == {(0, 1), (1, 2)}
    assert F.allclose(sg.edata["count"], F.tensor([1, 2]))
    assert F.allclose(sg.ndata["h"], g.ndata["h"])

    # Heterograph: build it with the parametrized idtype/device too.
    # Without them the device/idtype assertions below only ever check the
    # library defaults, and the idtype parametrization is ignored here.
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1, 1], [1, 2, 2]),
            ("user", "plays", "game"): ([0, 1, 0], [1, 1, 1]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    sg = transform(g)
    assert sg.device == g.device
    assert sg.idtype == g.idtype
    assert sg.ntypes == g.ntypes
    assert sg.canonical_etypes == g.canonical_etypes
    for nty in sg.ntypes:
        assert sg.num_nodes(nty) == g.num_nodes(nty)
    for ety in sg.canonical_etypes:
        assert sg.num_edges(ety) == 2

    # "follows": 1->2 is duplicated; "plays": 0->1 is duplicated.
    # These dict checks cover both the surviving edge set and the
    # per-edge multiplicities recorded by the transform.
    assert _count_by_edge(sg, "follows") == {(0, 1): 1, (1, 2): 2}
    assert _count_by_edge(sg, "plays") == {(0, 1): 2, (1, 1): 1}
| 2697 | |
| 2698 | |
| 2699 | @parametrize_idtype |