Parse the ONNX graph and create the corresponding TensorRT network definition. :param onnx_path: The path to the ONNX graph to load.
(self, onnx_path)
| 195 | self.parser = None |
| 196 | |
def create_network(self, onnx_path):
    """
    Parse the ONNX graph and create the corresponding TensorRT network definition.

    Side effects: sets ``self.network``, ``self.parser`` and ``self.batch_size``;
    also sets ``self.builder.max_batch_size`` when the builder still exposes it.
    If the ONNX file cannot be parsed, the parser errors are logged and the
    process exits with status 1.

    :param onnx_path: The path to the ONNX graph to load.
    :raises ValueError: If the parsed network has no inputs, or if the leading
        dimension of its last input is not a positive static size (for example
        ``-1`` for a dynamic batch dimension).
    """
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    self.network = self.builder.create_network(network_flags)
    self.parser = trt.OnnxParser(self.network, self.trt_logger)

    onnx_path = os.path.realpath(onnx_path)
    with open(onnx_path, "rb") as f:
        if not self.parser.parse(f.read()):
            log.error("Failed to load ONNX file: {}".format(onnx_path))
            for index in range(self.parser.num_errors):
                log.error(self.parser.get_error(index))
            sys.exit(1)

    inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]
    outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)]

    log.info("Network Description")
    for tensor in inputs:
        log.info("Input '{}' with shape {} and dtype {}".format(tensor.name, tensor.shape, tensor.dtype))
    for tensor in outputs:
        log.info("Output '{}' with shape {} and dtype {}".format(tensor.name, tensor.shape, tensor.dtype))

    # Without this check a network with no inputs would leave self.batch_size
    # untouched, so a stale value from an earlier call could slip through.
    if not inputs:
        raise ValueError("Network parsed from {} has no inputs".format(onnx_path))

    # The batch size is the leading dimension of the LAST input. This keeps the
    # previous behaviour, where each input overwrote self.batch_size in turn.
    self.batch_size = inputs[-1].shape[0]

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # A non-positive value (e.g. -1) indicates a dynamic batch dimension.
    if self.batch_size <= 0:
        raise ValueError(
            "Expected a positive static batch size from input '{}', got {}".format(
                inputs[-1].name, self.batch_size
            )
        )

    # max_batch_size only affects implicit-batch networks; this network is built
    # with EXPLICIT_BATCH. The attribute was removed in newer TensorRT releases
    # (NOTE(review): removed in TensorRT 10 — confirm for the version in use),
    # so only set it where the builder still provides it.
    if hasattr(self.builder, "max_batch_size"):
        self.builder.max_batch_size = self.batch_size
| 227 | def create_engine(self, engine_path, precision, calib_input=None, calib_cache=None, calib_num_images=25000, |
| 228 | calib_batch_size=8, calib_preprocessor=None): |
no test coverage detected