(model=None, args=None, input_names=None,
transforms="default", framework_transforms="default")
| 129 | ########################################################################### |
| 130 | |
def build_graph(model=None, args=None, input_names=None,
                transforms="default", framework_transforms="default"):
    """Import a PyTorch or TensorFlow model into a Graph and apply transforms.

    Args:
        model: The model to import. Its framework is determined by
            ``detect_framework``; only ``"torch"`` and ``"tensorflow"`` are
            handled here.
        args: Required when ``detect_framework`` reports ``"torch"``; it is
            forwarded to the PyTorch builder's ``import_graph``. (Presumably
            example inputs used to trace the model -- TODO confirm.) Not used
            for TensorFlow models.
        input_names: Accepted for API compatibility; not referenced by this
            function.
        transforms: Iterable of objects exposing ``apply(graph)``, applied
            after the framework transforms. ``"default"`` selects
            ``SIMPLICITY_TRANSFORMS`` from ``.transforms``; any falsy value
            (e.g. ``None`` or ``[]``) skips this step.
        framework_transforms: Iterable of objects exposing ``apply(graph)``,
            applied first. ``"default"`` selects the ``FRAMEWORK_TRANSFORMS``
            exported by the builder module for the detected framework; any
            falsy value skips this step.

    Returns:
        The graph returned by the last applied transform, or the freshly
        imported Graph if no transforms ran.

    Raises:
        ValueError: If the detected framework is not supported, or if a
            PyTorch model is given without ``args``.
    """
    # Initialize an empty graph
    g = Graph()

    # Detect framework
    framework = detect_framework(model)
    if framework == "torch":
        from .pytorch_builder import import_graph, FRAMEWORK_TRANSFORMS
        # Explicit check rather than `assert`: assertions are stripped under
        # `python -O`, which would let a missing `args` slip through silently.
        if args is None:
            raise ValueError("Argument args must be provided for Pytorch models.")
        import_graph(g, model, args)
    elif framework == "tensorflow":
        from .tf_builder import import_graph, FRAMEWORK_TRANSFORMS
        import_graph(g, model)
    else:
        raise ValueError("`model` input param must be a PyTorch, TensorFlow, or Keras-with-TensorFlow-backend model.")

    # Apply transforms. Each transform's apply() result becomes the graph
    # passed to the next one, so `g` is rebound on every step.
    if framework_transforms:
        if framework_transforms == "default":
            framework_transforms = FRAMEWORK_TRANSFORMS
        for t in framework_transforms:
            g = t.apply(g)
    if transforms:
        if transforms == "default":
            from .transforms import SIMPLICITY_TRANSFORMS
            transforms = SIMPLICITY_TRANSFORMS
        for t in transforms:
            g = t.apply(g)
    return g
| 161 | |
| 162 | |
| 163 | class Graph(): |
nothing calls this directly
no test coverage detected