| 245 | |
| 246 | |
class RewritePatternManager(_PatternManager):
    """Pattern manager that repeatedly applies its rewrite patterns to a network.

    Values of ``self.patterns`` are unpacked as ``(benefit, pattern)`` pairs.
    Patterns are tried in ascending order of that first element. Ties keep
    registration order because ``sorted`` is stable.
    """

    def rewrite(self, net: Network, args=None):
        """Apply the registered rewrite patterns to ``net`` until nothing changes.

        Each pass walks every non-removed layer of ``net`` and tries each
        pattern on it. A pattern with a non-``None`` ``root_layer`` is only tried
        on layers whose ``type`` is in ``root_layer``. Passes repeat until one
        completes without any pattern reporting a modification.

        Args:
            net: The network to rewrite in place.
            args: Opaque value handed to every pattern by assigning it to
                ``pattern.args`` before matching.
        """
        # Sort once per call instead of once per layer per pass.
        # NOTE(review): assumes patterns do not register new patterns on this
        # manager while rewriting - confirm.
        patterns = [
            pattern
            for _benefit, pattern in sorted(self.patterns.values(), key=lambda item: item[0])
        ]
        for pattern in patterns:
            pattern.args = args

        modified = True
        # TODO: we can optimize this by asking TRT to expose a graph iterator
        # that stays consistent even after the graph is modified.
        while modified:
            modified = False
            # Since the graph iterator is held by the underlying INetwork, we
            # can only rebuild the graph cache and match the nodes again.
            for layer in net.get_layers():
                if layer.is_removed():
                    continue
                for pattern in patterns:
                    if pattern.root_layer is not None and layer.type not in pattern.root_layer:
                        continue

                    if pattern._separate_match_rewrite:
                        # Two-phase pattern: only rewrite after a successful match.
                        changed = False
                        if pattern.match(layer):
                            pattern.rewrite(layer)
                            changed = True
                    else:
                        # Single-phase pattern reports whether it changed the graph.
                        changed = bool(pattern.match_and_rewrite(layer))

                    if changed:
                        modified = True
                        # The rewrite may have removed this layer. Removed
                        # layers are skipped above, so do not feed this one to
                        # the remaining patterns either.
                        if layer.is_removed():
                            break

    @staticmethod
    def instance():
        """Return the shared module-level ``_global_rewrite_pattern_manager``."""
        return _global_rewrite_pattern_manager
| 275 | |
| 276 | |
| 277 | class AnalysisPatternManager(_PatternManager): |
no outgoing calls