(self, model=None, args=None, input_names=None,
transforms="default", framework_transforms="default",
meaningful_ids=False)
| 164 | """Tracks nodes and edges of a directed graph and supports basic operations on them.""" |
| 165 | |
def __init__(self, model=None, args=None, input_names=None,
             transforms="default", framework_transforms="default",
             meaningful_ids=False):
    """Initialize an empty graph and optionally import ``model`` into it.

    Args:
        model: Model to import. Its framework is identified with
            ``detect_framework``; only ``"torch"`` and ``"tensorflow"`` are
            handled. When ``None``, the graph is left empty and no
            transforms are applied.
        args: Passed to the PyTorch ``import_graph``. Required when the model
            is detected as ``"torch"``; not used for TensorFlow.
        input_names: Accepted for API compatibility; not used by this method.
        framework_transforms: Iterable of objects with an ``apply(graph)``
            method, ``"default"`` to use the importer's
            ``FRAMEWORK_TRANSFORMS``, or a falsy value to skip this step.
        transforms: Iterable of objects with an ``apply(graph)`` method,
            applied after the framework transforms; ``"default"`` to use
            ``SIMPLICITY_TRANSFORMS``, or a falsy value to skip this step.
        meaningful_ids: Stored as ``self.meaningful_ids``.

    Raises:
        ValueError: If the detected framework is not supported, or if a
            PyTorch model is given without ``args``.
    """
    self.nodes = {}
    self.edges = []
    self.meaningful_ids = meaningful_ids  # TODO
    self.theme = THEMES["basic"]

    # ``None`` is the "no model" sentinel. Don't rely on truthiness: model
    # objects that define ``__len__`` can be falsy even when supplied.
    if model is None:
        return

    # Detect framework
    framework = detect_framework(model)
    if framework == "torch":
        # Validate explicitly (not with ``assert``, which -O strips) and
        # before importing the builder so the error is immediate and clear.
        if args is None:
            raise ValueError(
                "Argument args must be provided for Pytorch models.")
        from .pytorch_builder import import_graph, FRAMEWORK_TRANSFORMS
        import_graph(self, model, args)
    elif framework == "tensorflow":
        from .tf_builder import import_graph, FRAMEWORK_TRANSFORMS
        import_graph(self, model)
    else:
        # Without this branch FRAMEWORK_TRANSFORMS would be unbound below and
        # construction would fail with a confusing UnboundLocalError.
        raise ValueError(
            "Unsupported model framework: {!r}".format(framework))

    # Apply Transforms
    if framework_transforms:
        if framework_transforms == "default":
            framework_transforms = FRAMEWORK_TRANSFORMS
        for t in framework_transforms:
            t.apply(self)
    if transforms:
        if transforms == "default":
            from .transforms import SIMPLICITY_TRANSFORMS
            transforms = SIMPLICITY_TRANSFORMS
        for t in transforms:
            t.apply(self)
| 197 | |
| 198 | |
| 199 | def id(self, node): |
nothing calls this directly
no test coverage detected