MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / RewritePatternManager

Class RewritePatternManager

tensorrt_llm/graph_rewriting.py:247–274  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

245
246
class RewritePatternManager(_PatternManager):
    """Pattern manager that applies registered rewrite patterns to a network."""

    def rewrite(self, net: Network, args=None):
        """Apply every registered rewrite pattern to ``net`` until a fixpoint.

        Each pass walks all non-removed layers of ``net`` and tries the
        registered patterns on each layer. Patterns are tried in ascending
        order of the first element of their ``self.patterns`` entry (ties keep
        registration order, since ``sorted`` is stable). Passes repeat until a
        full pass reports no modification.

        Args:
            net: The network whose layers are matched and rewritten in place.
            args: Arbitrary value forwarded to every pattern through its
                ``args`` attribute before any matching happens.

        Note:
            There is no iteration cap: if some pattern keeps reporting a
            successful rewrite, this method never returns.
        """
        # The pattern order and their ``args`` do not depend on the layer being
        # visited, so compute/assign them once instead of per layer.
        ordered_patterns = [
            pattern
            for _, pattern in sorted(self.patterns.values(), key=lambda x: x[0])
        ]
        for pattern in ordered_patterns:
            pattern.args = args

        modified = True
        # TODO: we can optimize this by asking TRT to expose a graph iterator consistent even after
        # the graph is modified.
        while modified:
            modified = False
            # Since the graph iterator is held by the underlying INetwork, we can only rebuild the
            # graph cache and match the nodes again.
            for layer in net.get_layers():
                if layer.is_removed():
                    continue
                for pattern in ordered_patterns:
                    # Skip patterns that restrict which layer types they root on.
                    if (pattern.root_layer is not None
                            and layer.type not in pattern.root_layer):
                        continue
                    if pattern._separate_match_rewrite:
                        rewritten = pattern.match(layer)
                        if rewritten:
                            pattern.rewrite(layer)
                    else:
                        rewritten = pattern.match_and_rewrite(layer)
                    if rewritten:
                        modified = True
                        # A rewrite may remove the current layer; matching further
                        # patterns against a removed layer is invalid, so move on.
                        if layer.is_removed():
                            break

    @staticmethod
    def instance():
        """Return the module-level ``_global_rewrite_pattern_manager`` object.

        The object itself is defined elsewhere in this module.
        """
        return _global_rewrite_pattern_manager
275
276
277class AnalysisPatternManager(_PatternManager):

Calls

no outgoing calls