MCPcopy Index your code
hub / github.com/waleedka/hiddenlayer / build_graph

Function build_graph

hiddenlayer/graph.py:131–160  ·  view source on GitHub ↗
(model=None, args=None, input_names=None,
                transforms="default", framework_transforms="default")

Source from the content-addressed store, hash-verified

129###########################################################################
130
def build_graph(model=None, args=None, input_names=None,
                transforms="default", framework_transforms="default"):
    """Build a hiddenlayer ``Graph`` from a PyTorch or TensorFlow model.

    The model is imported with the builder that matches its detected
    framework. Then the framework transforms and the general transforms are
    applied, in that order.

    Args:
        model: The model to import. ``detect_framework`` must classify it as
            ``"torch"`` or ``"tensorflow"``.
        args: Argument(s) forwarded to the PyTorch builder's ``import_graph``.
            Required for PyTorch models; ignored for TensorFlow models.
        input_names: Accepted for backward compatibility; not used here.
        transforms: ``"default"`` to use ``SIMPLICITY_TRANSFORMS`` from
            ``.transforms``, an iterable of objects with an ``apply(graph)``
            method, or a falsy value (e.g. ``None``) to skip this step.
        framework_transforms: ``"default"`` to use the selected builder's
            ``FRAMEWORK_TRANSFORMS``, an iterable of transform objects, or a
            falsy value to skip this step.

    Returns:
        The graph returned by the last applied transform, or the imported
        graph if no transform was applied.

    Raises:
        ValueError: If the model's framework is not supported, if ``args`` is
            None for a PyTorch model, or if a transforms argument is a string
            other than ``"default"``.
    """
    # Initialize an empty graph
    g = Graph()

    # Detect framework and import the model with the matching builder.
    framework = detect_framework(model)
    if framework == "torch":
        from .pytorch_builder import import_graph, FRAMEWORK_TRANSFORMS
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if args is None:
            raise ValueError("Argument args must be provided for Pytorch models.")
        import_graph(g, model, args)
    elif framework == "tensorflow":
        from .tf_builder import import_graph, FRAMEWORK_TRANSFORMS
        import_graph(g, model)
    else:
        raise ValueError("`model` input param must be a PyTorch, TensorFlow, or Keras-with-TensorFlow-backend model.")

    # Apply framework-specific transforms first, then the general ones.
    if framework_transforms:
        if _wants_default(framework_transforms, "framework_transforms"):
            framework_transforms = FRAMEWORK_TRANSFORMS
        g = _apply_transforms(g, framework_transforms)
    if transforms:
        if _wants_default(transforms, "transforms"):
            # Imported lazily: only needed when the defaults are requested.
            from .transforms import SIMPLICITY_TRANSFORMS
            transforms = SIMPLICITY_TRANSFORMS
        g = _apply_transforms(g, transforms)
    return g


def _wants_default(value, param_name):
    """Return True if ``value`` is the string ``"default"``.

    Any other string raises ``ValueError``. Otherwise the caller would iterate
    the string and try to call ``apply`` on each character.
    """
    if not isinstance(value, str):
        return False
    if value != "default":
        raise ValueError(
            "`{}` must be \"default\", a falsy value, or an iterable of "
            "transforms; got {!r}.".format(param_name, value))
    return True


def _apply_transforms(g, transforms):
    """Apply each transform's ``apply`` to ``g`` in order and return the result."""
    for t in transforms:
        # Each transform returns the graph that the next one receives.
        g = t.apply(g)
    return g
161
162
163class Graph():

Callers

nothing calls this directly

Calls 4

Graph · Class · 0.85
detect_framework · Function · 0.85
import_graph · Function · 0.70
apply · Method · 0.45

Tested by

no test coverage detected