Get a graphviz representation of the network. NOTE: the graph might be redundant, since TRT's INetwork won't clean up unused inputs and layers automatically. TODO: add a flag to hide all the removed layers and their output tensors. TODO: replace this when TensorRT provides a better way to get the graph of INetworkDefinition.
(self, path: Union[str, Path] = None)
| 437 | yield layer |
| 438 | |
| 439 | def to_dot(self, path: Union[str, Path] = None) -> Optional[str]: |
| 440 | ''' |
| 441 | Get a graphviz representation of the network. |
| 442 | |
| 443 | NOTE, the graph might be redundancy since TRT's INetwork won't clean the unused inputs and layers |
| 444 | automatically. |
| 445 | TODO: add an flag to hide all the removed layers and their output tensors |
| 446 | TODO: replace this when TensorRT provides a better way to get the graph of INetworkDefinition |
| 447 | TODO: a little feature, add blocks in the figure to highlight the subgraphes of Modules |
| 448 | |
| 449 | Parameters: |
| 450 | path: the path to save the graphviz file, if not provided, will return the graphviz source code |
| 451 | ''' |
| 452 | format = 'text' if not path else path.split('.')[-1] |
| 453 | |
| 454 | try: |
| 455 | import graphviz |
| 456 | except ImportError: |
| 457 | logger.error( |
| 458 | "Failed to import graphviz, please install graphviz to enable Network.to_dot()" |
| 459 | ) |
| 460 | return |
| 461 | |
| 462 | dot = graphviz.Digraph( |
| 463 | comment= |
| 464 | f'TensorRT Graph of {self._get_network_hash(lightweight=False)}', |
| 465 | format=format if format != 'text' else None) |
| 466 | |
| 467 | inputs_names = set([x.name for x in self.get_inputs()]) |
| 468 | output_names = set([x.name for x in self.get_outputs()]) |
| 469 | |
| 470 | node_style = dict( |
| 471 | shape='box', |
| 472 | style='rounded,filled,bold', |
| 473 | fontname='Arial', |
| 474 | fillcolor='#ffffff', |
| 475 | color='#303A3A', |
| 476 | width='1.3', |
| 477 | height='0.84', |
| 478 | ) |
| 479 | |
| 480 | hl_node_style = dict( |
| 481 | shape='box', |
| 482 | style='rounded,filled,bold', |
| 483 | fontname='Arial', |
| 484 | fillcolor='lightblue', |
| 485 | color='#303A3A', |
| 486 | width='1.3', |
| 487 | height='0.84', |
| 488 | ) |
| 489 | |
| 490 | state = self._get_graph() |
| 491 | nodes = set() |
| 492 | tensor_to_alias = {} |
| 493 | tensor_id = [0] |
| 494 | |
| 495 | def get_alias(tensor, tensor_id): |
| 496 | if tensor not in tensor_to_alias: |