MCPcopy Index your code
hub / github.com/tensorlayer/TensorLayer / _construct_graph

Method _construct_graph

tensorlayer/models/core.py:692–751  ·  view source on GitHub ↗

Construct the computation graph for a static model using LayerNode objects.

(self)

Source retrieved from the content-addressed store (hash-verified).

690 raise ValueError('Either a layer name or a layer index should be given.')
691
692 def _construct_graph(self):
693 """construct computation graph for static model using LayerNode object"""
694 all_layers = []
695 node_by_depth = [] # [[node0, node1], [node2, node3], ...]
696
697 input_tensors_list = self.inputs if isinstance(self.inputs, list) else [self.inputs]
698
699 queue_node = Queue()
700
701 # BFS to visit all nodes that should be involved in the computation graph
702 output_tensors_list = self.outputs if isinstance(self.outputs, list) else [self.outputs]
703 output_nodes = [tensor._info[0] for tensor in output_tensors_list]
704
705 visited_node_names = set()
706 for out_node in output_nodes:
707 if out_node.visited:
708 continue
709 queue_node.put(out_node)
710
711 while not queue_node.empty():
712 cur_node = queue_node.get()
713 in_nodes = cur_node.in_nodes
714
715 for node in in_nodes:
716 node.out_nodes.append(cur_node)
717 if not node.visited:
718 queue_node.put(node)
719 node.visited = True
720 if node.name not in visited_node_names:
721 visited_node_names.add(node.name)
722 # else have multiple layers with the same name
723 else:
724 raise ValueError(
725 'Layer name \'%s\' has already been used by another layer. Please change the layer name.'
726 % node.layer.name
727 )
728
729 # construct the computation graph in top-sort order
730 cur_depth = [tensor._info[0] for tensor in input_tensors_list]
731 next_depth = []
732 indegrees = {}
733
734 visited_layer_names = []
735 while not len(cur_depth) == 0:
736 node_by_depth.append(cur_depth)
737 for node in cur_depth:
738 if node.layer.name not in visited_layer_names:
739 all_layers.append(node.layer)
740 visited_layer_names.append(node.layer.name)
741 for out_node in node.out_nodes:
742 if out_node.name not in indegrees.keys():
743 indegrees[out_node.name] = len(out_node.in_nodes)
744 indegrees[out_node.name] -= 1
745 if indegrees[out_node.name] == 0:
746 next_depth.append(out_node)
747
748 cur_depth = next_depth
749 next_depth = []

Callers 1

__init__ · Method · 0.95

Calls 2

get · Method · 0.80
add · Method · 0.45

Tested by

No test coverage detected.