MCPcopy
hub / github.com/OpenPPL/ppq / create_network

Method create_network

ppq/utils/TensorRTUtil.py:125–153  ·  view source on GitHub ↗

Parse the ONNX graph and create the corresponding TensorRT network definition.

:param onnx_path: The path to the ONNX graph to load.

(self, onnx_path)

Sourced from the content-addressed store, hash-verified

123 self.parser = None
124
def create_network(self, onnx_path):
    """
    Parse the ONNX graph and create the corresponding TensorRT network definition.

    The network is created with the EXPLICIT_BATCH flag. On success this sets
    ``self.network``, ``self.parser`` and ``self.batch_size`` (the leading
    dimension of the network's last input) and logs every input and output.
    It also forwards the batch size to ``self.builder.max_batch_size``, but
    only when the installed TensorRT still exposes that attribute.

    :param onnx_path: The path to the ONNX graph to load.
    :raises SystemExit: exits the process with status 1 if the ONNX parser
        rejects the file (each parser error is logged first).
    :raises ValueError: if the parsed network has no inputs, or if the last
        input does not have a static, positive leading (batch) dimension.
    """
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    self.network = self.builder.create_network(network_flags)
    self.parser = trt.OnnxParser(self.network, self.trt_logger)

    onnx_path = os.path.realpath(onnx_path)
    with open(onnx_path, "rb") as f:
        model_bytes = f.read()
    if not self.parser.parse(model_bytes):
        log.error("Failed to load ONNX file: {}".format(onnx_path))
        for error_index in range(self.parser.num_errors):
            log.error(self.parser.get_error(error_index))
        sys.exit(1)

    inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]
    outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)]

    log.info("Network Description")
    for tensor in inputs:
        log.info("Input '{}' with shape {} and dtype {}".format(tensor.name, tensor.shape, tensor.dtype))
    for tensor in outputs:
        log.info("Output '{}' with shape {} and dtype {}".format(tensor.name, tensor.shape, tensor.dtype))

    # The batch size is the leading dimension of the last input. Validate it
    # with explicit checks rather than `assert` (which `python -O` strips).
    # With no inputs, fail clearly instead of raising AttributeError or
    # reusing a batch size left over from a previous call.
    if not inputs:
        raise ValueError("ONNX network has no inputs: {}".format(onnx_path))
    last_input = inputs[-1]
    batch_shape = last_input.shape
    if len(batch_shape) == 0 or batch_shape[0] <= 0:
        # TensorRT reports dynamic dimensions as -1.
        raise ValueError(
            "Input '{}' has shape {}; a static, positive batch dimension is required".format(
                last_input.name, batch_shape))
    self.batch_size = batch_shape[0]

    # max_batch_size only affects implicit-batch networks. It is deprecated in
    # TensorRT 8.x and absent from the TensorRT 10 Builder, where assigning it
    # would raise AttributeError, so set it only where it still exists.
    if hasattr(self.builder, "max_batch_size"):
        self.builder.max_batch_size = self.batch_size
154
155 def create_engine(self, engine_path, precision, calib_input=None, calib_cache=None, calib_num_images=25000,
156 calib_batch_size=8, calib_preprocessor=None):

Callers 3

build_int8_engine · Function · 0.45
build_engine · Function · 0.45
build_engine · Function · 0.45

Calls 3

error · Method · 0.80
info · Method · 0.80
parse · Method · 0.45

Tested by

no test coverage detected