(idtype)
@unittest.skipIf(F._default_context_str == "gpu", reason="GPU not implemented")
@parametrize_idtype
def test_prop_nodes_topo(idtype):
    """Check ``dgl.prop_nodes_topo`` on a cyclic graph and on a small tree.

    The bi-directional chain from ``create_graph`` contains a cycle, so
    topological propagation is expected to fail on it. On a tree whose
    edges point from children to parents, propagating ``x`` with the
    module-level ``mfunc``/``rfunc`` is expected to leave the root holding
    the sum of the leaf values (asserted at the end).

    Args:
        idtype: Node/edge id dtype supplied by ``@parametrize_idtype``; used
            for both the chain graph and the tree.
    """
    # Filtering DGLWarning:
    # The input graph for the user-defined edge
    # function does not contain valid edges
    import warnings

    # bi-directional chain
    g = create_graph(idtype)
    assert check_fail(dgl.prop_nodes_topo, g)  # has loop

    # tree (edges point child -> parent):
    #       0
    #      / \
    #     1   2
    #        / \
    #       3   4
    # Build it in one call with the parametrized idtype. The previous
    # add_nodes/add_edges + dgl.graph(tree.edges()) round-trip produced the
    # same edges (same ids, same order) but always with the default idtype,
    # so this half of the test never exercised ``idtype``.
    tree = dgl.graph(([1, 2, 3, 4], [0, 0, 2, 2]), num_nodes=5, idtype=idtype)
    # init node feature data
    tree.ndata["x"] = F.zeros((5, 2))
    # set all leaf nodes to be ones
    tree.nodes[[1, 3, 4]].data["x"] = F.ones((3, 2))

    with warnings.catch_warnings():
        # The leaf frontier has no incoming edges, so DGL emits a warning
        # (DGLWarning is filtered here via its UserWarning base).
        warnings.simplefilter("ignore", category=UserWarning)
        dgl.prop_nodes_topo(
            tree, message_func=mfunc, reduce_func=rfunc, apply_node_func=None
        )
    # root node gets the sum of the three leaves
    assert F.allclose(tree.nodes[0].data["x"], F.tensor([[3.0, 3.0]]))
| 116 | |
| 117 | |
| 118 | if __name__ == "__main__": |
no test coverage detected