(deploy_file, model_file, json_file, engine_file)
| 53 | |
| 54 | |
def build_int8_engine(deploy_file, model_file, json_file, engine_file):
    """Build an INT8 TensorRT engine from a Caffe model and write it to disk.

    The Caffe network is parsed into a TensorRT network, every blob named in
    ``ModelData.OUTPUT_NAME`` is marked as an output, the INT8 builder flag is
    set, ``setDynamicRange(network, json_file)`` is invoked, and the built
    engine is serialized to ``engine_file``.

    Args:
        deploy_file: Path to the Caffe deploy (network definition) file.
        model_file: Path to the Caffe model (weights) file.
        json_file: Path handed to ``setDynamicRange`` together with the
            network (presumably per-tensor dynamic ranges -- confirm against
            ``setDynamicRange``).
        engine_file: Destination path; the serialized engine is written here
            in binary mode, overwriting any existing file.

    Raises:
        SystemExit: If ``deploy_file`` or ``model_file`` does not exist.
        RuntimeError: If the Caffe parser fails, a requested output blob is
            not present in the parsed model, or the engine cannot be built.
    """
    # Function-scope import: ``quit`` is an interactive helper injected by the
    # ``site`` module and is absent under ``python -S`` / frozen builds, so
    # ``sys.exit`` is used instead. It raises the same SystemExit.
    import sys

    # Validate the inputs before allocating any TensorRT objects so that a
    # bad path fails fast.
    if not os.path.exists(deploy_file):
        sys.exit('deploy_file file {} not found'.format(deploy_file))
    if not os.path.exists(model_file):
        sys.exit('model_file file {} not found'.format(model_file))

    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(EXPLICIT_BATCH)

    config = builder.create_builder_config()
    config.max_workspace_size = GiB(1)

    # If it is a dynamic onnx model, you need to add the following.
    # profile = builder.create_optimization_profile()
    # profile.set_shape("input_name", (batch, channels, min_h, min_w), (batch, channels, opt_h, opt_w), (batch, channels, max_h, max_w))
    # config.add_optimization_profile(profile)

    parser = trt.CaffeParser()
    model_tensors = parser.parse(
        deploy=deploy_file,
        model=model_file,
        network=network,
        dtype=ModelData.DTYPE,
    )
    if model_tensors is None:
        raise RuntimeError(
            'failed to parse Caffe model (deploy={}, model={})'.format(
                deploy_file, model_file))

    for act_name in ModelData.OUTPUT_NAME:
        output_tensor = model_tensors.find(act_name)
        # A missing blob would otherwise be passed on as None and fail
        # inside mark_output with an unrelated error.
        if output_tensor is None:
            raise RuntimeError(
                'output blob {} not found in parsed model'.format(act_name))
        network.mark_output(output_tensor)

    config.set_flag(trt.BuilderFlag.INT8)

    setDynamicRange(network, json_file)

    engine = builder.build_engine(network, config)
    # Guard against a failed build before serializing; otherwise the failure
    # would only show up as an AttributeError on ``engine.serialize()``.
    if engine is None:
        raise RuntimeError('failed to build INT8 TensorRT engine')

    with open(engine_file, "wb") as f:
        f.write(engine.serialize())
| 87 | |
| 88 | |
| 89 | if __name__ == '__main__': |
no test coverage detected