Construct the computation graph for a static model using LayerNode objects.
(self)
| 690 | raise ValueError('Either a layer name or a layer index should be given.') |
| 691 | |
| 692 | def _construct_graph(self): |
| 693 | """construct computation graph for static model using LayerNode object""" |
| 694 | all_layers = [] |
| 695 | node_by_depth = [] # [[node0, node1], [node2, node3], ...] |
| 696 | |
| 697 | input_tensors_list = self.inputs if isinstance(self.inputs, list) else [self.inputs] |
| 698 | |
| 699 | queue_node = Queue() |
| 700 | |
| 701 | # BFS to visit all nodes that should be involved in the computation graph |
| 702 | output_tensors_list = self.outputs if isinstance(self.outputs, list) else [self.outputs] |
| 703 | output_nodes = [tensor._info[0] for tensor in output_tensors_list] |
| 704 | |
| 705 | visited_node_names = set() |
| 706 | for out_node in output_nodes: |
| 707 | if out_node.visited: |
| 708 | continue |
| 709 | queue_node.put(out_node) |
| 710 | |
| 711 | while not queue_node.empty(): |
| 712 | cur_node = queue_node.get() |
| 713 | in_nodes = cur_node.in_nodes |
| 714 | |
| 715 | for node in in_nodes: |
| 716 | node.out_nodes.append(cur_node) |
| 717 | if not node.visited: |
| 718 | queue_node.put(node) |
| 719 | node.visited = True |
| 720 | if node.name not in visited_node_names: |
| 721 | visited_node_names.add(node.name) |
| 722 | # else have multiple layers with the same name |
| 723 | else: |
| 724 | raise ValueError( |
| 725 | 'Layer name \'%s\' has already been used by another layer. Please change the layer name.' |
| 726 | % node.layer.name |
| 727 | ) |
| 728 | |
| 729 | # construct the computation graph in top-sort order |
| 730 | cur_depth = [tensor._info[0] for tensor in input_tensors_list] |
| 731 | next_depth = [] |
| 732 | indegrees = {} |
| 733 | |
| 734 | visited_layer_names = [] |
| 735 | while not len(cur_depth) == 0: |
| 736 | node_by_depth.append(cur_depth) |
| 737 | for node in cur_depth: |
| 738 | if node.layer.name not in visited_layer_names: |
| 739 | all_layers.append(node.layer) |
| 740 | visited_layer_names.append(node.layer.name) |
| 741 | for out_node in node.out_nodes: |
| 742 | if out_node.name not in indegrees.keys(): |
| 743 | indegrees[out_node.name] = len(out_node.in_nodes) |
| 744 | indegrees[out_node.name] -= 1 |
| 745 | if indegrees[out_node.name] == 0: |
| 746 | next_depth.append(out_node) |
| 747 | |
| 748 | cur_depth = next_depth |
| 749 | next_depth = [] |