()
| 25 | |
@clear_cache_before_run()
def test_graph_manipulation():
    """Check leaf/top-node detection and BFS-level assignment on a traced MLP.

    The graph produced by tracing ``MLP(4)`` with ``ColoTracer`` is expected to
    contain exactly seven nodes, in order: a placeholder, five intermediate
    nodes (``l1``..``l5``), and the output node.
    """
    model = MLP(4)
    tracer = ColoTracer()
    graph = tracer.trace(model)
    # The 7-way unpacking doubles as an assertion on the node count.
    x, l1, l2, l3, l4, l5, output = list(graph.nodes)

    # Query leaf/top nodes before levels are assigned, as the original test did.
    leaf_nodes = set(get_leaf(graph))
    top_nodes = set(get_top(graph))
    assign_bfs_level_to_nodes(graph)

    assert leaf_nodes == {l4, l5}, f"unexpected leaf nodes: {leaf_nodes}"
    assert top_nodes == {l1, l2}, f"unexpected top nodes: {top_nodes}"

    # Expected BFS level for every node that is neither a placeholder nor the output.
    expected_levels = {l1: 0, l2: 0, l3: 1, l4: 1, l5: 2}
    for node in graph.nodes:
        if node.op in ("placeholder", "output"):
            # Placeholder and output nodes must not be tagged with a level at all.
            assert not hasattr(node, "bfs_level"), f"{node} unexpectedly has bfs_level"
        else:
            assert node.bfs_level == expected_levels[node], (
                f"{node}: bfs_level {node.bfs_level} != expected {expected_levels[node]}"
            )
| 47 | |
| 48 | if __name__ == "__main__": |
no test coverage detected
searching dependent graphs…