(onnx_file, json_file, engine_file)
| 46 | |
| 47 | |
def build_engine(onnx_file, json_file, engine_file):
    """Build an INT8 TensorRT engine from an ONNX model and save it to disk.

    INT8 dynamic ranges are not calibrated here; they are applied to the
    network by ``setDynamicRange(network, json_file)``, a helper defined
    elsewhere in this file (the JSON layout it expects is not visible here).

    Args:
        onnx_file: Path to the ONNX model to parse.
        json_file: Path handed to ``setDynamicRange`` for per-tensor ranges.
        engine_file: Destination path for the serialized engine.

    Returns:
        The built engine on success, or ``None`` if the ONNX file could not
        be parsed or TensorRT failed to build the engine.

    Raises:
        SystemExit: If ``onnx_file`` does not exist.
    """
    import sys

    # Fail fast before creating any TensorRT objects. sys.exit(msg) raises the
    # same SystemExit as quit(msg), but quit() is only defined when the `site`
    # module is loaded and is meant for interactive use.
    if not os.path.exists(onnx_file):
        sys.exit('ONNX file {} not found'.format(onnx_file))

    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(EXPLICIT_BATCH)
    config = builder.create_builder_config()

    # For an ONNX model with dynamic input shapes, register an optimization
    # profile here, e.g.:
    #   profile = builder.create_optimization_profile()
    #   profile.set_shape("input_name", min_shape, opt_shape, max_shape)
    #   config.add_optimization_profile(profile)

    parser = trt.OnnxParser(network, TRT_LOGGER)

    # `max_workspace_size` is deprecated in newer TensorRT releases; prefer the
    # memory-pool API when this TensorRT version provides it.
    if hasattr(config, 'set_memory_pool_limit'):
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, GiB(1))
    else:
        config.max_workspace_size = GiB(1)

    with open(onnx_file, 'rb') as model:
        if not parser.parse(model.read()):
            print('ERROR: Failed to parse the ONNX file.')
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            return None

    config.set_flag(trt.BuilderFlag.INT8)
    setDynamicRange(network, json_file)

    # build_engine signals failure by returning None rather than raising;
    # without this check serialize() below would fail with an AttributeError.
    engine = builder.build_engine(network, config)
    if engine is None:
        print('ERROR: Failed to build the TensorRT engine.')
        return None

    with open(engine_file, 'wb') as f:
        f.write(engine.serialize())
    return engine
| 81 | |
| 82 | |
| 83 | if __name__ == '__main__': |
no test coverage detected