MCPcopy Index your code
hub / github.com/waleedka/hiddenlayer / __init__

Method __init__

hiddenlayer/graph.py:166–196  ·  view source on GitHub ↗
(self, model=None, args=None, input_names=None,
                 transforms="default", framework_transforms="default",
                 meaningful_ids=False)

Source from the content-addressed store, hash-verified

164 """Tracks nodes and edges of a directed graph and supports basic operations on them."""
165
def __init__(self, model=None, args=None, input_names=None,
             transforms="default", framework_transforms="default",
             meaningful_ids=False):
    """Create a graph, optionally importing its nodes and edges from a model.

    Args:
        model: Model to import. The framework is chosen by
            ``detect_framework``; "torch" and "tensorflow" are handled.
            If None, an empty graph is created and no transforms are
            applied.
        args: Passed to the PyTorch importer as ``import_graph(self,
            model, args)``. Required when the model is detected as
            "torch"; not used for TensorFlow.
        input_names: Accepted for API compatibility; not used by this
            method.
        framework_transforms: Iterable of transforms applied first, each
            via ``t.apply(self)``. ``"default"`` selects the importing
            builder module's ``FRAMEWORK_TRANSFORMS``; a falsy value skips
            this step.
        transforms: Iterable of transforms applied after
            ``framework_transforms``. ``"default"`` selects
            ``SIMPLICITY_TRANSFORMS`` from ``.transforms``; a falsy value
            skips this step.
        meaningful_ids: Stored on the instance (marked TODO).

    Raises:
        ValueError: If the model is detected as "torch" and ``args`` is
            None, or if ``detect_framework`` reports a framework other
            than "torch" or "tensorflow".
    """
    self.nodes = {}
    self.edges = []
    self.meaningful_ids = meaningful_ids  # TODO
    self.theme = THEMES["basic"]

    # Compare against None instead of relying on truthiness. Container
    # modules such as torch.nn.Sequential define __len__, so an empty one
    # is falsy and would otherwise be skipped silently.
    if model is None:
        return

    # Detect framework
    framework = detect_framework(model)
    if framework == "torch":
        from .pytorch_builder import import_graph, FRAMEWORK_TRANSFORMS
        # Raise instead of assert: assert statements are removed under -O.
        if args is None:
            raise ValueError("Argument args must be provided for Pytorch models.")
        import_graph(self, model, args)
    elif framework == "tensorflow":
        from .tf_builder import import_graph, FRAMEWORK_TRANSFORMS
        import_graph(self, model)
    else:
        # Without this branch, FRAMEWORK_TRANSFORMS is never bound. The
        # default path would then fail with an UnboundLocalError, and a
        # falsy framework_transforms would silently yield an empty graph.
        raise ValueError(
            "Unsupported model type {!r}: expected a PyTorch or TensorFlow "
            "model.".format(type(model).__name__))

    # Apply Transforms: framework-specific ones first, then simplifications.
    if framework_transforms:
        if framework_transforms == "default":
            framework_transforms = FRAMEWORK_TRANSFORMS
        for t in framework_transforms:
            t.apply(self)
    if transforms:
        if transforms == "default":
            from .transforms import SIMPLICITY_TRANSFORMS
            transforms = SIMPLICITY_TRANSFORMS
        for t in transforms:
            t.apply(self)
197
198
199 def id(self, node):

Callers

nothing calls this directly

Calls 3

detect_framework — Function · 0.85
import_graph — Function · 0.70
apply — Method · 0.45

Tested by

no test coverage detected