MCP · copy · Index your code
hub / github.com/waleedka/hiddenlayer / Graph

Class Graph

hiddenlayer/graph.py:163–365  ·  view source on GitHub ↗

Tracks nodes and edges of a directed graph and supports basic operations on them.

Source from the content-addressed store, hash-verified

161
162
163class Graph():
164 """Tracks nodes and edges of a directed graph and supports basic operations on them."""
165
def __init__(self, model=None, args=None, input_names=None,
             transforms="default", framework_transforms="default",
             meaningful_ids=False):
    """Create a graph, optionally importing it from a framework model.

    Args:
        model: Model to import. It is dispatched on the value returned by
            `detect_framework(model)` ("torch" or "tensorflow"). When None,
            an empty graph is created and no transforms are applied.
        args: Passed to the PyTorch importer together with the model;
            required when the detected framework is "torch". Not passed to
            the TensorFlow importer.
        input_names: Accepted for API compatibility; not used here.
        transforms: Iterable of objects with an `apply(graph)` method,
            applied after the framework transforms. The string "default"
            selects `SIMPLICITY_TRANSFORMS` from `.transforms`; a falsy
            value skips this stage.
        framework_transforms: Iterable of objects with an `apply(graph)`
            method, applied right after import. The string "default" selects
            the importing builder's `FRAMEWORK_TRANSFORMS`; a falsy value
            skips this stage.
        meaningful_ids: Stored on the instance; not otherwise used here.

    Raises:
        ValueError: If the model is detected as PyTorch and `args` is None,
            or if `detect_framework` reports a framework other than
            "torch" or "tensorflow".
    """
    self.nodes = {}
    self.edges = []
    # TODO: meaningful_ids is stored but not yet acted upon.
    self.meaningful_ids = meaningful_ids
    self.theme = THEMES["basic"]

    if model is None:
        # Empty graph: nothing to import and nothing to transform.
        return

    # Detect the framework and import with the matching builder. The
    # builder modules are imported inside each branch (presumably so that
    # only the framework actually in use has to be installed).
    framework = detect_framework(model)
    if framework == "torch":
        from .pytorch_builder import import_graph, FRAMEWORK_TRANSFORMS
        if args is None:
            # Explicit raise instead of `assert`, which is stripped under -O.
            raise ValueError("Argument args must be provided for Pytorch models.")
        import_graph(self, model, args)
    elif framework == "tensorflow":
        from .tf_builder import import_graph, FRAMEWORK_TRANSFORMS
        import_graph(self, model)
    else:
        # Without this branch FRAMEWORK_TRANSFORMS would be unbound below and
        # the caller would get an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported model type {!r}: expected a PyTorch or TensorFlow "
            "model.".format(type(model).__name__))

    # Apply the builder's framework transforms first, then the general ones.
    if framework_transforms:
        if framework_transforms == "default":
            framework_transforms = FRAMEWORK_TRANSFORMS
        for t in framework_transforms:
            t.apply(self)
    if transforms:
        if transforms == "default":
            from .transforms import SIMPLICITY_TRANSFORMS
            transforms = SIMPLICITY_TRANSFORMS
        for t in transforms:
            t.apply(self)
197
198
def id(self, node):
    """Return the key under which `node` is stored in this graph.

    A node's own `id` attribute is preferred when it exists; otherwise the
    value of `hash(node)` is used. Note that hash values are not guaranteed
    to be distinct for distinct objects.
    """
    absent = object()
    own_id = getattr(node, "id", absent)
    if own_id is absent:
        return hash(node)
    return own_id
203
def add_node(self, node):
    """Add `node` to the graph, keyed by `self.id(node)`.

    If a node with the same identifier is already present, it is silently
    replaced by `node`.
    """
    # Named `node_id` rather than `id` to avoid shadowing the builtin.
    node_id = self.id(node)
    self.nodes[node_id] = node
208
def add_edge(self, node1, node2, label=None):
    """Connect `node1` to `node2`, optionally with a `label`.

    The edge is recorded in `self.edges` as the tuple
    `(self.id(node1), self.id(node2), label)`. If that exact tuple is
    already present, nothing is added. An edge between the same two nodes
    with a *different* label is still added.
    """
    # TODO: also skip the edge when one with a different label already exists.
    candidate = (self.id(node1), self.id(node2), label)
    if candidate in self.edges:
        return
    self.edges.append(candidate)
215
def add_edge_by_id(self, vid1, vid2, label=None):
    """Record an edge between two node identifiers.

    The identifiers are stored as given, as `(vid1, vid2, label)`. Unlike
    `add_edge`, this performs no duplicate check, so the same edge may be
    recorded more than once.
    """
    new_edge = (vid1, vid2, label)
    self.edges.append(new_edge)
218
219 def outgoing(self, node):
220 """Returns nodes connecting out of the given node (or list of nodes)."""

Callers 1

build_graph (Function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected