| 2698 | |
@parametrize_idtype
def test_module_line_graph(idtype):
    """Exercise the ``dgl.LineGraph`` transform on a small 3-edge graph.

    The input graph has edges 0->1, 1->0 and 1->2. For each configuration
    of the transform the result must keep the input's device and idtype,
    have one node per input edge, and have exactly the expected edge set:

    * default construction: ``{(0, 1), (0, 2), (1, 0)}``
    * ``backtracking=False``: ``{(0, 2)}``
    """
    # Each case pairs the constructor keyword arguments with the edge set
    # the resulting line graph must have.
    cases = (
        ({}, {(0, 1), (0, 2), (1, 0)}),
        ({"backtracking": False}, {(0, 2)}),
    )

    graph = dgl.graph(([0, 1, 1], [1, 0, 2]), idtype=idtype, device=F.ctx())
    # Node/edge features are attached to the input but are not asserted on.
    graph.ndata["h"] = F.tensor([[0.0], [1.0], [2.0]])
    graph.edata["w"] = F.tensor([[0.0], [0.1], [0.2]])

    for ctor_kwargs, expected_edges in cases:
        line_graph = dgl.LineGraph(**ctor_kwargs)(graph)

        assert line_graph.device == graph.device
        assert line_graph.idtype == graph.idtype
        assert line_graph.num_nodes() == graph.num_edges()

        heads, tails = line_graph.edges()
        actual_edges = set(zip(F.asnumpy(heads), F.asnumpy(tails)))
        assert actual_edges == expected_edges
| 2721 | |
| 2722 | |
| 2723 | @parametrize_idtype |