Replace nodes with node. Edges incoming to nodes[0] are connected to the new node, and edges outgoing from nodes[-1] become outgoing from the new node.
(self, nodes, node)
| 262 | del self.nodes[k] |
| 263 | |
def replace(self, nodes, node):
    """Replace a node, or a list of nodes, with a single node.

    Every node returned by ``self.incoming(nodes)`` gets an edge into
    ``node``, and every node returned by ``self.outgoing(nodes)`` gets an
    edge from ``node``. The replaced nodes are then dropped through
    ``self.remove``.

    Args:
        nodes: A node or a list of nodes to replace. Anything that is not
            a ``list`` is treated as one single node.
        node: The replacement node. When its id is already present in
            ``self.nodes`` it is not added again, and it is kept even if
            it also appears in ``nodes`` (collapsing a group into one of
            its own members).
    """
    group = nodes if isinstance(nodes, list) else [nodes]

    # A replacement that is already registered in the graph must neither be
    # re-added nor removed together with the rest of the group.
    already_present = self.id(node) in self.nodes
    if not already_present:
        self.add_node(node)

    # Each new edge carries the output_shape of its source node, or None
    # when the source has no such attribute.
    # TODO: check specifically for output_shape is not generic. Consider refactoring.
    for source in self.incoming(group):
        self.add_edge(source, node, getattr(source, "output_shape", None))
    for target in self.outgoing(group):
        self.add_edge(node, target, getattr(node, "output_shape", None))

    # Drop the replaced nodes, sparing the replacement itself when collapsing.
    for old in group:
        if not (already_present and old == node):
            self.remove(old)
| 285 | |
| 286 | def search(self, pattern): |
| 287 | """Searches the graph for a sub-graph that matches the given pattern |