
Method to_dot

tensorrt_llm/network.py:439–534

Get a graphviz representation of the network. NOTE: the graph may contain redundant nodes, because TRT's INetwork does not automatically remove unused inputs and layers.
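The NOTE means the drawn graph can include layers whose outputs never reach a network output, and a TODO in the docstring asks for a flag to hide them. One way such a filter could work is a backward reachability pass from the network outputs. The sketch below assumes a toy representation (a dict of hypothetical layer names mapped to input and output tensor names); it is not TensorRT's `INetworkDefinition` API and not how TensorRT-LLM implements anything.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical toy graph: layer name -> (input tensor names, output tensor names).
# This is NOT TensorRT's INetworkDefinition API; it only illustrates the idea.
Layers = Dict[str, Tuple[List[str], List[str]]]


def live_layers(layers: Layers, network_outputs: Set[str]) -> Set[str]:
    """Return the layers that contribute to at least one network output.

    Walks backward from the outputs: a layer is live if one of its output
    tensors is needed, and a live layer makes all of its inputs needed.
    """
    producer: Dict[str, str] = {}  # tensor name -> name of the layer that produces it
    for name, (_, outs) in layers.items():
        for t in outs:
            producer[t] = name

    needed = list(network_outputs)  # worklist of tensors still to resolve
    seen_tensors: Set[str] = set()
    live: Set[str] = set()
    while needed:
        t = needed.pop()
        if t in seen_tensors:
            continue
        seen_tensors.add(t)
        layer = producer.get(t)
        if layer is None or layer in live:
            continue  # a network input, or a layer whose inputs are already queued
        live.add(layer)
        needed.extend(layers[layer][0])
    return live


if __name__ == "__main__":
    toy = {
        "matmul": (["x", "w"], ["h"]),
        "relu": (["h"], ["y"]),
        "debug_cast": (["h"], ["h_fp32"]),  # output never consumed downstream
    }
    live = live_layers(toy, {"y"})
    print("live:", sorted(live))
    print("removable:", sorted(set(toy) - live))
```

Layers in `set(toy) - live` (here `debug_cast`) are the candidates such a flag would hide, together with their output tensors.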

(self, path: Union[str, Path] = None)
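The signature accepts `path` as either `str` or `Path`, but the method body derives the output format with `path.split('.')[-1]`, and `split` exists only on `str`; passing a `pathlib.Path` would raise `AttributeError`. A minimal sketch of a format helper that accepts both types is shown below. `infer_format` is a hypothetical name, not part of TensorRT-LLM, and raising on a missing extension is an assumed policy (the original expression would instead return the last dot-separated piece of the whole path).

```python
from pathlib import Path
from typing import Optional, Union


def infer_format(path: Optional[Union[str, Path]]) -> str:
    """Hypothetical helper: map an output path to a graphviz format string.

    Mirrors the intent of `'text' if not path else path.split('.')[-1]`,
    but goes through Path.suffix so both str and pathlib.Path work.
    """
    if not path:
        return 'text'  # no path: the caller returns DOT source instead of writing a file
    suffix = Path(path).suffix  # e.g. '.png'; '' when the file name has no extension
    if not suffix:
        raise ValueError(f"cannot infer an output format from {str(path)!r}")
    return suffix[1:]  # drop the leading dot


if __name__ == "__main__":
    for p in (None, 'graph.png', Path('out/net.svg'), 'my.model.dot'):
        print(repr(p), '->', infer_format(p))
```

Using `Path.suffix` also avoids a subtle case in the string version: for a path such as `dir.v1/graph`, `split('.')[-1]` yields `v1/graph`, while `Path('dir.v1/graph').suffix` is empty.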


```python
def to_dot(self, path: Union[str, Path] = None) -> Optional[str]:
    '''
    Get a graphviz representation of the network.

    NOTE, the graph might be redundancy since TRT's INetwork won't clean the unused inputs and layers
    automatically.
    TODO: add an flag to hide all the removed layers and their output tensors
    TODO: replace this when TensorRT provides a better way to get the graph of INetworkDefinition
    TODO: a little feature, add blocks in the figure to highlight the subgraphes of Modules

    Parameters:
        path: the path to save the graphviz file, if not provided, will return the graphviz source code
    '''
    # Derive the output format from the file extension ('text' means return DOT source).
    # Note: `path.split` assumes a str; a pathlib.Path here would raise AttributeError.
    format = 'text' if not path else path.split('.')[-1]

    try:
        import graphviz
    except ImportError:
        logger.error(
            "Failed to import graphviz, please install graphviz to enable Network.to_dot()"
        )
        return

    dot = graphviz.Digraph(
        comment=
        f'TensorRT Graph of {self._get_network_hash(lightweight=False)}',
        format=format if format != 'text' else None)

    # Names of the network's input and output tensors.
    inputs_names = set([x.name for x in self.get_inputs()])
    output_names = set([x.name for x in self.get_outputs()])

    # Default node style.
    node_style = dict(
        shape='box',
        style='rounded,filled,bold',
        fontname='Arial',
        fillcolor='#ffffff',
        color='#303A3A',
        width='1.3',
        height='0.84',
    )

    # Highlighted node style: identical except for the light-blue fill.
    hl_node_style = dict(
        shape='box',
        style='rounded,filled,bold',
        fontname='Arial',
        fillcolor='lightblue',
        color='#303A3A',
        width='1.3',
        height='0.84',
    )

    state = self._get_graph()
    nodes = set()
    tensor_to_alias = {}
    tensor_id = [0]  # one-element list so the nested helper can update the counter in place

    def get_alias(tensor, tensor_id):
        if tensor not in tensor_to_alias:
            # ... (listing truncated here; the method continues through line 534 of network.py)
```
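The listing stops inside `get_alias`, so the exact aliasing and highlighting logic is not shown. The sketch below illustrates the general pattern the visible setup suggests: short sequential aliases stored in a `tensor_to_alias` dict, a `tensor_id = [0]` counter mutated from a nested function, and the same two node styles. Assumptions: the graph is a toy list of hypothetical `(layer, inputs, outputs)` tuples rather than TensorRT's network, the caller chooses which tensors to highlight (for example the names collected in `inputs_names` and `output_names`), and plain strings replace `graphviz.Digraph` to keep it dependency-free. It is not TensorRT-LLM's implementation.

```python
from typing import Dict, List, Set, Tuple

# Same attribute values as node_style / hl_node_style in the to_dot listing.
NODE_STYLE = dict(shape='box', style='rounded,filled,bold', fontname='Arial',
                  fillcolor='#ffffff', color='#303A3A', width='1.3', height='0.84')
HL_NODE_STYLE = dict(NODE_STYLE, fillcolor='lightblue')  # only the fill differs


def _attrs(style: Dict[str, str]) -> str:
    # Render a style dict as DOT attributes: key="value", key="value", ...
    return ', '.join(f'{k}="{v}"' for k, v in style.items())


def to_dot_source(layers: List[Tuple[str, List[str], List[str]]],
                  highlight: Set[str]) -> str:
    """Emit DOT text for a toy network (hypothetical sketch).

    layers: (layer name, input tensor names, output tensor names)
    highlight: tensor names drawn with HL_NODE_STYLE
    """
    tensor_to_alias: Dict[str, str] = {}
    tensor_id = [0]  # a list can be mutated from the nested function without `nonlocal`

    def get_alias(tensor: str) -> str:
        # Assign t0, t1, ... the first time a tensor is seen.
        if tensor not in tensor_to_alias:
            tensor_to_alias[tensor] = f't{tensor_id[0]}'
            tensor_id[0] += 1
        return tensor_to_alias[tensor]

    lines = ['digraph {']
    declared: Set[str] = set()

    def declare(tensor: str) -> str:
        # Emit a tensor node once, using the highlighted style if requested.
        alias = get_alias(tensor)
        if alias not in declared:
            style = HL_NODE_STYLE if tensor in highlight else NODE_STYLE
            lines.append(f'  {alias} [label="{tensor}", {_attrs(style)}]')
            declared.add(alias)
        return alias

    for name, inputs, outputs in layers:
        lines.append(f'  "{name}" [shape=ellipse]')
        for t in inputs:
            alias = declare(t)
            lines.append(f'  {alias} -> "{name}"')
        for t in outputs:
            alias = declare(t)
            lines.append(f'  "{name}" -> {alias}')
    lines.append('}')
    return '\n'.join(lines)


if __name__ == "__main__":
    net = [("matmul", ["x", "w"], ["h"]), ("relu", ["h"], ["y"])]
    print(to_dot_source(net, highlight={"x", "y"}))
```

The printed text is ordinary DOT, so it can be rendered with the Graphviz command-line tool (for example `dot -Tpng`). Aliasing keeps node IDs short and valid even when TensorRT tensor names contain characters that DOT would require quoting.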

Callers (2)

test_to_dot · Method · 0.80
build_engine · Function · 0.80

Calls (7)

_get_network_hash · Method · 0.95
get_inputs · Method · 0.95
get_outputs · Method · 0.95
_get_graph · Method · 0.95
split · Method · 0.45
error · Method · 0.45
save · Method · 0.45

Tested by (1)

test_to_dot · Method · 0.64