Parse the ONNX graph and create the corresponding TensorRT network definition. :param onnx_path: The path to the ONNX graph to load.
(self, onnx_path)
| 123 | self.parser = None |
| 124 | |
def create_network(self, onnx_path):
    """
    Parse the ONNX graph and create the corresponding TensorRT network definition.

    Sets ``self.network`` and ``self.parser``, logs every network input and
    output, and stores the leading dimension of the (last) network input in
    ``self.batch_size``. On a parse failure, a network without inputs, or a
    non-positive batch dimension, the error is logged and the process exits
    with status 1.

    :param onnx_path: The path to the ONNX graph to load.
    """
    # Explicit-batch networks carry the batch dimension inside each tensor shape.
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    self.network = self.builder.create_network(network_flags)
    self.parser = trt.OnnxParser(self.network, self.trt_logger)

    onnx_path = os.path.realpath(onnx_path)
    with open(onnx_path, "rb") as f:
        if not self.parser.parse(f.read()):
            log.error("Failed to load ONNX file: {}".format(onnx_path))
            for error in range(self.parser.num_errors):
                log.error(self.parser.get_error(error))
            sys.exit(1)

    inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]
    outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)]

    log.info("Network Description")
    for tensor in inputs:
        log.info("Input '{}' with shape {} and dtype {}".format(tensor.name, tensor.shape, tensor.dtype))
    for tensor in outputs:
        log.info("Output '{}' with shape {} and dtype {}".format(tensor.name, tensor.shape, tensor.dtype))

    # Without any inputs there is no batch dimension to read; previously this
    # surfaced as an AttributeError or silently reused a stale self.batch_size.
    if not inputs:
        log.error("ONNX network has no inputs: {}".format(onnx_path))
        sys.exit(1)

    # As before, the batch size comes from the leading dimension of the last
    # input. NOTE(review): assumes all inputs share the same batch dimension.
    self.batch_size = inputs[-1].shape[0]

    # An explicit check rather than `assert`, which is stripped under `python -O`.
    # A dynamic batch dimension (-1) is rejected here, as before.
    if self.batch_size <= 0:
        log.error("Invalid batch size {} in ONNX network: {}".format(self.batch_size, onnx_path))
        sys.exit(1)

    # max_batch_size is ignored for explicit-batch networks. It is deprecated
    # in TensorRT 8 and removed in TensorRT 10, so only set it when available.
    if hasattr(self.builder, "max_batch_size"):
        self.builder.max_batch_size = self.batch_size
| 154 | |
| 155 | def create_engine(self, engine_path, precision, calib_input=None, calib_cache=None, calib_num_images=25000, |
| 156 | calib_batch_size=8, calib_preprocessor=None): |
no test coverage detected