Tracks nodes and edges of a directed graph and supports basic operations on them.
| 161 | |
| 162 | |
| 163 | class Graph(): |
| 164 | """Tracks nodes and edges of a directed graph and supports basic operations on them.""" |
| 165 | |
def __init__(self, model=None, args=None, input_names=None,
             transforms="default", framework_transforms="default",
             meaningful_ids=False):
    """Create a graph, optionally importing it from a framework model.

    Args:
        model: A model to import. Its framework is determined with
            ``detect_framework``; "torch" and "tensorflow" are supported.
            When ``None``, an empty graph is created and no transforms run.
        args: Passed to the PyTorch importer. Required for PyTorch models.
        input_names: Accepted for API compatibility; not used by this
            constructor.
        transforms: Iterable of transforms applied after the framework
            transforms. The string "default" selects
            ``SIMPLICITY_TRANSFORMS``; a falsy value disables this step.
        framework_transforms: Iterable of transforms applied right after
            import. "default" selects the importer module's
            ``FRAMEWORK_TRANSFORMS``; a falsy value disables this step.
        meaningful_ids: Stored on the instance (not yet used; see TODO).

    Raises:
        ValueError: If ``model`` belongs to an unsupported framework, or if
            ``args`` is missing for a PyTorch model.
    """
    self.nodes = {}
    self.edges = []
    self.meaningful_ids = meaningful_ids  # TODO
    self.theme = THEMES["basic"]

    # Use an explicit None check: truthiness would skip models that define
    # __len__ and happen to be empty.
    if model is None:
        return

    # Detect the framework and delegate construction to its importer.
    framework = detect_framework(model)
    if framework == "torch":
        from .pytorch_builder import import_graph, FRAMEWORK_TRANSFORMS
        # Raise instead of assert so the check survives `python -O`.
        if args is None:
            raise ValueError("Argument args must be provided for Pytorch models.")
        import_graph(self, model, args)
    elif framework == "tensorflow":
        from .tf_builder import import_graph, FRAMEWORK_TRANSFORMS
        import_graph(self, model)
    else:
        # Without this branch, import_graph / FRAMEWORK_TRANSFORMS would be
        # unbound below and fail with an unhelpful NameError (or silently
        # yield an empty graph).
        raise ValueError(f"Unsupported framework: {framework!r}")

    # Apply framework-specific transforms first, then generic simplifications.
    if framework_transforms:
        if framework_transforms == "default":
            framework_transforms = FRAMEWORK_TRANSFORMS
        for t in framework_transforms:
            t.apply(self)
    if transforms:
        if transforms == "default":
            from .transforms import SIMPLICITY_TRANSFORMS
            transforms = SIMPLICITY_TRANSFORMS
        for t in transforms:
            t.apply(self)
| 197 | |
| 198 | |
def id(self, node):
    """Return the key under which ``node`` is stored in this graph.

    A node's own ``id`` attribute is preferred when it has one; otherwise
    ``hash(node)`` is used. Hash values are not guaranteed to be unique, so
    nodes without an ``id`` attribute could in principle collide.
    """
    if not hasattr(node, "id"):
        return hash(node)
    return node.id
| 203 | |
def add_node(self, node):
    """Store ``node`` in ``self.nodes`` keyed by ``self.id(node)``.

    Duplicate identifiers are not rejected: a later node with the same
    identifier replaces the earlier one.
    """
    self.nodes[self.id(node)] = node
| 208 | |
def add_edge(self, node1, node2, label=None):
    """Add a directed edge from ``node1`` to ``node2`` unless it already exists.

    Edges are stored in ``self.edges`` as ``(source_id, target_id, label)``
    tuples, with ids obtained from ``self.id``. The duplicate check compares
    the whole tuple, so the same endpoints with a different label are still
    added as a separate edge.
    """
    # TODO: If an edge exists with a different label, still don't add it again.
    candidate = (self.id(node1), self.id(node2), label)
    if candidate in self.edges:
        return
    self.edges.append(candidate)
| 215 | |
def add_edge_by_id(self, vid1, vid2, label=None):
    """Append a directed edge ``(vid1, vid2, label)`` to ``self.edges``.

    Unlike ``add_edge``, the endpoints are taken as identifiers directly
    (no ``self.id`` lookup), and no duplicate check is performed.
    """
    new_edge = (vid1, vid2, label)
    self.edges.append(new_edge)
| 218 | |
| 219 | def outgoing(self, node): |
| 220 | """Returns nodes connecting out of the given node (or list of nodes).""" |