Repository: github.com/NVIDIA/TensorRT-LLM

Method rewrite

tensorrt_llm/graph_rewriting.py:248–270
(self, net: Network, args=None)

```python
# Excerpt from tensorrt_llm/graph_rewriting.py.
# Network and _PatternManager are defined outside this excerpt.

class RewritePatternManager(_PatternManager):
    def rewrite(self, net: Network, args=None):
        modified = True
        # TODO: we can optimize this by asking TRT to expose a graph iterator consistent even after
        # the graph is modified.
        while modified:
            modified = False
            # Since the graph iterator is held by the underlying INetwork, we can only rebuild the
            # graph cache and match the nodes again.
            for layer in net.get_layers():
                if layer.is_removed():
                    continue
                for profit, pattern in sorted(self.patterns.values(), key=lambda x: x[0]):
                    pattern.args = args

                    if pattern.root_layer is not None and layer.type not in pattern.root_layer:
                        continue
                    if pattern._separate_match_rewrite:
                        if pattern.match(layer):
                            pattern.rewrite(layer)
                            modified = True
                    else:
                        if pattern.match_and_rewrite(layer):
                            modified = True

    @staticmethod
    def instance():
        ...  # body not shown in this excerpt
```
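The loop above is a fixed-point iteration: it re-fetches the layer list and re-applies every pattern until one complete pass changes nothing. The following is a minimal, self-contained sketch of that control flow on a toy graph. `ToyLayer`, `ToyNetwork`, `DropIdentity` and `rewrite_to_fixed_point` are hypothetical stand-ins, not TensorRT-LLM APIs; only the loop structure (skip removed layers, try patterns in ascending order of their first tuple element, repeat while anything changed) mirrors the excerpt.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class ToyLayer:
    # Hypothetical stand-in for a layer wrapper; `removed` marks deleted layers.
    name: str
    type: str
    removed: bool = False

    def is_removed(self) -> bool:
        return self.removed


@dataclass
class ToyNetwork:
    # Hypothetical stand-in for Network; get_layers() returns a fresh snapshot.
    layers: List[ToyLayer] = field(default_factory=list)

    def get_layers(self) -> List[ToyLayer]:
        return list(self.layers)


class DropIdentity:
    # Hypothetical pattern: marks every IDENTITY layer as removed.
    def __init__(self):
        self.args = None

    def match_and_rewrite(self, layer: ToyLayer) -> bool:
        if layer.type == "IDENTITY" and not layer.removed:
            layer.removed = True
            return True
        return False


def rewrite_to_fixed_point(net: ToyNetwork,
                           patterns: Dict[str, Tuple[int, object]],
                           args=None) -> int:
    """Repeat full passes until a pass modifies nothing; return the pass count."""
    passes = 0
    modified = True
    while modified:
        modified = False
        passes += 1
        for layer in net.get_layers():
            if layer.is_removed():
                continue
            # Same ordering as the excerpt: ascending by the first tuple element.
            for _profit, pattern in sorted(patterns.values(), key=lambda x: x[0]):
                pattern.args = args
                if pattern.match_and_rewrite(layer):
                    modified = True
    return passes


if __name__ == "__main__":
    net = ToyNetwork([ToyLayer("a", "MATMUL"),
                      ToyLayer("b", "IDENTITY"),
                      ToyLayer("c", "IDENTITY")])
    print(rewrite_to_fixed_point(net, {"drop_identity": (0, DropIdentity())}))  # 2
    print([l.name for l in net.layers if not l.is_removed()])  # ['a']
```

Because `modified` is set whenever any rewrite fires, the loop only exits after a full pass in which no pattern applies, so any successful rewrite costs at least one extra verification pass. It also means a pattern that keeps matching layers it has already rewritten would keep the loop running indefinitely; each pattern has to stop matching once its rewrite is done.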

Callers (2)

test_pattern_rewriter (method) · 0.95
optimize (function) · 0.95

Calls (6)

get_layers (method) · 0.80
is_removed (method) · 0.80
values (method) · 0.45
match (method) · 0.45
rewrite (method) · 0.45
match_and_rewrite (method) · 0.45
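The calls to `match`, `rewrite` and `match_and_rewrite` reflect the two pattern protocols the loop dispatches on: a separate `match` followed by `rewrite` when `_separate_match_rewrite` is set, otherwise a single `match_and_rewrite` call, with `root_layer` acting as an optional pre-filter on `layer.type`. The sketch below isolates that per-layer dispatch. The pattern classes and the `try_pattern` helper are hypothetical, and layer types are plain strings here rather than whatever type objects the real layers expose.

```python
from typing import List, Optional, Set


class ToyLayer:
    # Hypothetical layer stand-in exposing only a `type` attribute.
    def __init__(self, type_: str):
        self.type = type_


class TwoStepPattern:
    # Hypothetical pattern using separate match() and rewrite() calls.
    _separate_match_rewrite = True

    def __init__(self, root_layer: Optional[Set[str]] = None):
        self.root_layer = root_layer
        self.args = None
        self.rewritten: List[str] = []

    def match(self, layer: ToyLayer) -> bool:
        return layer.type == "ELEMENTWISE"

    def rewrite(self, layer: ToyLayer) -> None:
        # A real pattern would edit the network; here we only record the call.
        self.rewritten.append(layer.type)


class OneStepPattern:
    # Hypothetical pattern that matches and rewrites in a single call.
    _separate_match_rewrite = False

    def __init__(self, root_layer: Optional[Set[str]] = None):
        self.root_layer = root_layer
        self.args = None

    def match_and_rewrite(self, layer: ToyLayer) -> bool:
        return layer.type == "SHUFFLE"


def try_pattern(pattern, layer: ToyLayer, args=None) -> bool:
    """Apply one pattern to one layer; return True if it reported a rewrite."""
    pattern.args = args
    # Cheap type filter applied before any matching logic runs.
    if pattern.root_layer is not None and layer.type not in pattern.root_layer:
        return False
    if pattern._separate_match_rewrite:
        if pattern.match(layer):
            pattern.rewrite(layer)
            return True
        return False
    return pattern.match_and_rewrite(layer)


if __name__ == "__main__":
    two = TwoStepPattern(root_layer={"ELEMENTWISE"})
    print(try_pattern(two, ToyLayer("ELEMENTWISE")))  # True
    print(try_pattern(two, ToyLayer("SHUFFLE")))      # False (filtered by root_layer)
    print(try_pattern(OneStepPattern(), ToyLayer("SHUFFLE")))  # True
```

In `rewrite`, the boolean returned here corresponds to setting `modified = True`; the `root_layer` check lets a pattern skip most layers without running its own matching code.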

Tested by (1)

test_pattern_rewriter (method) · 0.76